aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/text_format.py
diff options
context:
space:
mode:
authorFeng Xiao <xfxyjwf@gmail.com>2018-08-08 17:00:41 -0700
committerFeng Xiao <xfxyjwf@gmail.com>2018-08-08 17:00:41 -0700
commit6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3 (patch)
treee575738adf52d24b883cca5e8928a5ded31caba1 /python/google/protobuf/text_format.py
parente7746f487cb9cca685ffb1b3d7dccc5554b618a4 (diff)
downloadprotobuf-6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3.tar.gz
protobuf-6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3.tar.bz2
protobuf-6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3.zip
Down-integrate from google3.
Diffstat (limited to 'python/google/protobuf/text_format.py')
-rwxr-xr-xpython/google/protobuf/text_format.py134
1 files changed, 88 insertions, 46 deletions
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 2cbd21bc..5dd41830 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -55,15 +55,15 @@ from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import text_encoding
-__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue',
- 'Merge']
+__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField',
+ 'PrintFieldValue', 'Merge', 'MessageToBytes']
_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(),
type_checkers.Int32ValueChecker(),
type_checkers.Uint64ValueChecker(),
type_checkers.Int64ValueChecker())
-_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE)
-_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
+_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE)
+_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE)
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
_QUOTES = frozenset(("'", '"'))
@@ -121,6 +121,7 @@ class TextWriter(object):
def MessageToString(message,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
@@ -128,6 +129,7 @@ def MessageToString(message,
descriptor_pool=None,
indent=0,
message_formatter=None):
+ # type: (...) -> str
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -137,8 +139,11 @@ def MessageToString(message,
Args:
message: The protocol buffers message.
- as_utf8: Produce text output in UTF8 format.
+ as_utf8: Return unescaped Unicode for non-ASCII characters.
+ In Python 3 actual Unicode characters may appear as is in strings.
+ In Python 2 the return value will be valid UTF-8 rather than only ASCII.
as_one_line: Don't introduce newlines between fields.
+ use_short_repeated_primitives: Use short repeated format for primitives.
pointy_brackets: If True, use angle brackets instead of curly braces for
nesting.
use_index_order: If True, fields of a proto message will be printed using
@@ -159,7 +164,8 @@ def MessageToString(message,
A string of the text formatted protocol buffer message.
"""
out = TextWriter(as_utf8)
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, use_field_number,
descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -170,6 +176,16 @@ def MessageToString(message,
return result
+def MessageToBytes(message, **kwargs):
+ # type: (...) -> bytes
+ """Convert protobuf message to encoded text format. See MessageToString."""
+ text = MessageToString(message, **kwargs)
+ if isinstance(text, bytes):
+ return text
+ codec = 'utf-8' if kwargs.get('as_utf8') else 'ascii'
+ return text.encode(codec)
+
+
def _IsMapEntry(field):
return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
field.message_type.has_options and
@@ -181,13 +197,15 @@ def PrintMessage(message,
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
use_field_number=False,
descriptor_pool=None,
message_formatter=None):
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, use_field_number,
descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -199,12 +217,14 @@ def PrintField(field,
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
message_formatter=None):
"""Print a single field name/value pair."""
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, message_formatter)
printer.PrintField(field, value)
@@ -215,12 +235,14 @@ def PrintFieldValue(field,
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
message_formatter=None):
"""Print a single field value (not including name)."""
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, message_formatter)
printer.PrintFieldValue(field, value)
@@ -258,6 +280,7 @@ class _Printer(object):
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
@@ -274,8 +297,11 @@ class _Printer(object):
Args:
out: To record the text format result.
indent: The indent level for pretty print.
- as_utf8: Produce text output in UTF8 format.
+ as_utf8: Return unescaped Unicode for non-ASCII characters.
+ In Python 3 actual Unicode characters may appear as is in strings.
+ In Python 2 the return value will be valid UTF-8 rather than ASCII.
as_one_line: Don't introduce newlines between fields.
+ use_short_repeated_primitives: Use short repeated format for primitives.
pointy_brackets: If True, use angle brackets instead of curly braces for
nesting.
use_index_order: If True, print fields of a proto message using the order
@@ -294,6 +320,7 @@ class _Printer(object):
self.indent = indent
self.as_utf8 = as_utf8
self.as_one_line = as_one_line
+ self.use_short_repeated_primitives = use_short_repeated_primitives
self.pointy_brackets = pointy_brackets
self.use_index_order = use_index_order
self.float_format = float_format
@@ -351,13 +378,18 @@ class _Printer(object):
entry_submsg = value.GetEntryClass()(key=key, value=value[key])
self.PrintField(field, entry_submsg)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- for element in value:
- self.PrintField(field, element)
+ if (self.use_short_repeated_primitives
+ and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE
+ and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING):
+ self._PrintShortRepeatedPrimitivesValue(field, value)
+ else:
+ for element in value:
+ self.PrintField(field, element)
else:
self.PrintField(field, value)
- def PrintField(self, field, value):
- """Print a single field name/value pair."""
+ def _PrintFieldName(self, field):
+ """Print field name."""
out = self.out
out.write(' ' * self.indent)
if self.use_field_number:
@@ -383,11 +415,22 @@ class _Printer(object):
# don't include it.
out.write(': ')
+ def PrintField(self, field, value):
+ """Print a single field name/value pair."""
+ self._PrintFieldName(field)
self.PrintFieldValue(field, value)
- if self.as_one_line:
- out.write(' ')
- else:
- out.write('\n')
+ self.out.write(' ' if self.as_one_line else '\n')
+
+ def _PrintShortRepeatedPrimitivesValue(self, field, value):
+ # Note: this is called only when value has at least one element.
+ self._PrintFieldName(field)
+ self.out.write('[')
+ for i in xrange(len(value) - 1):
+ self.PrintFieldValue(field, value[i])
+ self.out.write(', ')
+ self.PrintFieldValue(field, value[-1])
+ self.out.write(']')
+ self.out.write(' ' if self.as_one_line else '\n')
def _PrintMessageFieldValue(self, value):
if self.pointy_brackets:
@@ -428,12 +471,12 @@ class _Printer(object):
out.write(str(value))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
out.write('\"')
- if isinstance(value, six.text_type):
+ if isinstance(value, six.text_type) and (six.PY2 or not self.as_utf8):
out_value = value.encode('utf-8')
else:
out_value = value
if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- # We need to escape non-UTF8 chars in TYPE_BYTES field.
+ # We always need to escape all binary data in TYPE_BYTES fields.
out_as_utf8 = False
else:
out_as_utf8 = self.as_utf8
@@ -487,12 +530,7 @@ def Parse(text,
Raises:
ParseError: On text parsing problems.
"""
- if not isinstance(text, str):
- if six.PY3:
- text = text.decode('utf-8')
- else:
- text = text.encode('utf-8')
- return ParseLines(text.split('\n'),
+ return ParseLines(text.split(b'\n' if isinstance(text, bytes) else u'\n'),
message,
allow_unknown_extension,
allow_field_number,
@@ -523,13 +561,8 @@ def Merge(text,
Raises:
ParseError: On text parsing problems.
"""
- if not isinstance(text, str):
- if six.PY3:
- text = text.decode('utf-8')
- else:
- text = text.encode('utf-8')
return MergeLines(
- text.split('\n'),
+ text.split(b'\n' if isinstance(text, bytes) else u'\n'),
message,
allow_unknown_extension,
allow_field_number,
@@ -570,6 +603,9 @@ def MergeLines(lines,
descriptor_pool=None):
"""Parses a text representation of a protocol message into a message.
+ Like ParseLines(), but allows repeated values for a non-repeated field, and
+ uses the last one.
+
Args:
lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
@@ -601,22 +637,12 @@ class _Parser(object):
self.allow_field_number = allow_field_number
self.descriptor_pool = descriptor_pool
- def ParseFromString(self, text, message):
- """Parses a text representation of a protocol message into a message."""
- if not isinstance(text, str):
- text = text.decode('utf-8')
- return self.ParseLines(text.split('\n'), message)
-
def ParseLines(self, lines, message):
"""Parses a text representation of a protocol message into a message."""
self._allow_multiple_scalars = False
self._ParseOrMerge(lines, message)
return message
- def MergeFromString(self, text, message):
- """Merges a text representation of a protocol message into a message."""
- return self._MergeLines(text.split('\n'), message)
-
def MergeLines(self, lines, message):
"""Merges a text representation of a protocol message into a message."""
self._allow_multiple_scalars = True
@@ -633,7 +659,14 @@ class _Parser(object):
Raises:
ParseError: On text parsing problems.
"""
- tokenizer = Tokenizer(lines)
+ # Tokenize expects native str lines.
+ if six.PY2:
+ str_lines = (line if isinstance(line, str) else line.encode('utf-8')
+ for line in lines)
+ else:
+ str_lines = (line if isinstance(line, str) else line.decode('utf-8')
+ for line in lines)
+ tokenizer = Tokenizer(str_lines)
while not tokenizer.AtEnd():
self._MergeField(tokenizer, message)
@@ -1019,7 +1052,9 @@ class Tokenizer(object):
r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number
] + [ # quoted str for each quote mark
- r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
+ # Avoid backtracking! https://stackoverflow.com/a/844267
+ r'{qt}[^{qt}\n\\]*((\\.)+[^{qt}\n\\]*)*({qt}|\\?$)'.format(qt=mark)
+ for mark in _QUOTES
]))
_IDENTIFIER = re.compile(r'[^\d\W]\w*')
@@ -1316,7 +1351,8 @@ class Tokenizer(object):
def ParseError(self, message):
"""Creates and *returns* a ParseError for the current token."""
- return ParseError(message, self._line + 1, self._column + 1)
+ return ParseError('\'' + self._current_line + '\': ' + message,
+ self._line + 1, self._column + 1)
def _StringParseError(self, e):
return self.ParseError('Couldn\'t parse string: ' + str(e))
@@ -1490,6 +1526,12 @@ def _ParseAbstractInteger(text, is_long=False):
ValueError: Thrown Iff the text is not a valid integer.
"""
# Do the actual parsing. Exception handling is propagated to caller.
+ orig_text = text
+ c_octal_match = re.match(r'(-?)0(\d+)$', text)
+ if c_octal_match:
+ # Python 3 no longer supports 0755 octal syntax without the 'o', so
+ # we always use the '0o' prefix for multi-digit numbers starting with 0.
+ text = c_octal_match.group(1) + '0o' + c_octal_match.group(2)
try:
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
@@ -1499,7 +1541,7 @@ def _ParseAbstractInteger(text, is_long=False):
else:
return int(text, 0)
except ValueError:
- raise ValueError('Couldn\'t parse integer: %s' % text)
+ raise ValueError('Couldn\'t parse integer: %s' % orig_text)
def ParseFloat(text):