# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Contains well known classes.

This file defines well known classes which need extra maintenance including:
  - Any
  - Duration
  - FieldMask
  - ListValue
  - Struct
  - Timestamp
"""

__author__ = 'jieluo@google.com (Jie Luo)'

from datetime import datetime
from datetime import timedelta

import six

from google.protobuf.descriptor import FieldDescriptor

_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S'
_NANOS_PER_SECOND = 1000000000
_NANOS_PER_MILLISECOND = 1000000
_NANOS_PER_MICROSECOND = 1000
_MILLIS_PER_SECOND = 1000
_MICROS_PER_SECOND = 1000000
_SECONDS_PER_DAY = 24 * 3600


class Error(Exception):
  """Top-level module error."""


class ParseError(Error):
  """Thrown in case of parsing error."""


class Any(object):
  """Class for Any Message type."""

  def Pack(self, msg, type_url_prefix='type.googleapis.com/'):
    """Packs the specified message into current Any message.

    Args:
      msg: The message to pack. Its serialized bytes are stored in ``value``.
      type_url_prefix: Prefix for ``type_url``. A '/' separator is inserted
        between the prefix and the message's full name unless the prefix
        already ends with '/'.
    """
    if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
      self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
    else:
      self.type_url = '%s%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
    self.value = msg.SerializeToString()

  def Unpack(self, msg):
    """Unpacks the current Any message into specified message.

    Returns:
      True if the packed type matches ``msg``'s type and ``msg`` was parsed
      from ``value``; False otherwise (``msg`` is not touched).
    """
    descriptor = msg.DESCRIPTOR
    if not self.Is(descriptor):
      return False
    msg.ParseFromString(self.value)
    return True

  def Is(self, descriptor):
    """Checks if this Any represents the given protobuf type."""
    # Only last part is to be used: b/25630112
    return self.type_url.split('/')[-1] == descriptor.full_name


class Timestamp(object):
  """Class for Timestamp message type."""

  def ToJsonString(self):
    """Converts Timestamp to RFC 3339 date string format.

    Returns:
      A string converted from timestamp. The string is always Z-normalized
      and uses 3, 6 or 9 fractional digits as required to represent the
      exact time. Example of the return format: '1972-01-01T10:00:20.021Z'
    """
    nanos = self.nanos % _NANOS_PER_SECOND
    total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND
    seconds = total_sec % _SECONDS_PER_DAY
    days = (total_sec - seconds) // _SECONDS_PER_DAY
    dt = datetime(1970, 1, 1) + timedelta(days, seconds)

    result = dt.isoformat()
    if (nanos % 1e9) == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 'Z'
    if (nanos % 1e6) == 0:
      # Serialize 3 fractional digits.
      return result + '.%03dZ' % (nanos / 1e6)
    if (nanos % 1e3) == 0:
      # Serialize 6 fractional digits.
      return result + '.%06dZ' % (nanos / 1e3)
    # Serialize 9 fractional digits.
    return result + '.%09dZ' % nanos

  def FromJsonString(self, value):
    """Parse a RFC 3339 date string format to Timestamp.

    Args:
      value: A date string. Any fractional digits (or none) and any offset are
        accepted as long as they fit into nano-seconds precision.
        Example of accepted format: '1972-01-01T10:00:20.021-05:00'

    Raises:
      ParseError: On parsing problems.
    """
    timezone_offset = value.find('Z')
    if timezone_offset == -1:
      timezone_offset = value.find('+')
    if timezone_offset == -1:
      # rfind: the date part itself contains '-' separators.
      timezone_offset = value.rfind('-')
    if timezone_offset == -1:
      raise ParseError(
          'Failed to parse timestamp: missing valid timezone offset.')
    time_value = value[0:timezone_offset]
    # Parse datetime and nanos.
    point_position = time_value.find('.')
    if point_position == -1:
      second_value = time_value
      nano_value = ''
    else:
      second_value = time_value[:point_position]
      nano_value = time_value[point_position + 1:]
    date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT)
    td = date_object - datetime(1970, 1, 1)
    seconds = td.seconds + td.days * _SECONDS_PER_DAY
    if len(nano_value) > 9:
      raise ParseError(
          'Failed to parse Timestamp: nanos {0} more than '
          '9 fractional digits.'.format(nano_value))
    if nano_value:
      nanos = round(float('0.' + nano_value) * 1e9)
    else:
      nanos = 0
    # Parse timezone offsets.
    if value[timezone_offset] == 'Z':
      if len(value) != timezone_offset + 1:
        raise ParseError('Failed to parse timestamp: invalid trailing'
                         ' data {0}.'.format(value))
    else:
      timezone = value[timezone_offset:]
      pos = timezone.find(':')
      if pos == -1:
        raise ParseError(
            'Invalid timezone offset value: {0}.'.format(timezone))
      # Local time = UTC + offset, so subtract a '+' offset to get UTC.
      if timezone[0] == '+':
        seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
      else:
        seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
    # Set seconds and nanos
    self.seconds = int(seconds)
    self.nanos = int(nanos)

  def GetCurrentTime(self):
    """Get the current UTC into Timestamp."""
    self.FromDatetime(datetime.utcnow())

  def ToNanoseconds(self):
    """Converts Timestamp to nanoseconds since epoch."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts Timestamp to microseconds since epoch."""
    return (self.seconds * _MICROS_PER_SECOND +
            self.nanos // _NANOS_PER_MICROSECOND)

  def ToMilliseconds(self):
    """Converts Timestamp to milliseconds since epoch."""
    return (self.seconds * _MILLIS_PER_SECOND +
            self.nanos // _NANOS_PER_MILLISECOND)

  def ToSeconds(self):
    """Converts Timestamp to seconds since epoch."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds since epoch to Timestamp."""
    self.seconds = nanos // _NANOS_PER_SECOND
    self.nanos = nanos % _NANOS_PER_SECOND

  def FromMicroseconds(self, micros):
    """Converts microseconds since epoch to Timestamp."""
    self.seconds = micros // _MICROS_PER_SECOND
    self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND

  def FromMilliseconds(self, millis):
    """Converts milliseconds since epoch to Timestamp."""
    self.seconds = millis // _MILLIS_PER_SECOND
    self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND

  def FromSeconds(self, seconds):
    """Converts seconds since epoch to Timestamp."""
    self.seconds = seconds
    self.nanos = 0

  def ToDatetime(self):
    """Converts Timestamp to a naive datetime expressed in UTC.

    The conversion uses exact integer arithmetic (no float round trip), so
    seconds keep full precision and pre-epoch values work. Sub-microsecond
    nanos are truncated toward zero.
    """
    return datetime(1970, 1, 1) + timedelta(
        seconds=self.seconds,
        microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND))

  def FromDatetime(self, dt):
    """Converts datetime to Timestamp.

    Naive datetimes are interpreted as UTC. Timezone-aware datetimes are
    first shifted to UTC using their ``utcoffset()``.
    """
    offset = dt.utcoffset()
    if offset is not None:
      # Drop tzinfo and apply the offset so the subtraction below works on
      # two naive UTC datetimes.
      dt = dt.replace(tzinfo=None) - offset
    td = dt - datetime(1970, 1, 1)
    self.seconds = td.seconds + td.days * _SECONDS_PER_DAY
    self.nanos = td.microseconds * _NANOS_PER_MICROSECOND


class Duration(object):
  """Class for Duration message type."""

  def ToJsonString(self):
    """Converts Duration to string format.

    Returns:
      A string converted from self. The string format will contain
      3, 6, or 9 fractional digits depending on the precision required to
      represent the exact Duration value. For example: "1s", "1.010s",
      "1.000000100s", "-3.100s"
    """
    if self.seconds < 0 or self.nanos < 0:
      result = '-'
      seconds = - self.seconds + int((0 - self.nanos) // 1e9)
      nanos = (0 - self.nanos) % 1e9
    else:
      result = ''
      seconds = self.seconds + int(self.nanos // 1e9)
      nanos = self.nanos % 1e9
    result += '%d' % seconds
    if (nanos % 1e9) == 0:
      # If there are 0 fractional digits, the fractional
      # point '.' should be omitted when serializing.
      return result + 's'
    if (nanos % 1e6) == 0:
      # Serialize 3 fractional digits.
      return result + '.%03ds' % (nanos / 1e6)
    if (nanos % 1e3) == 0:
      # Serialize 6 fractional digits.
      return result + '.%06ds' % (nanos / 1e3)
    # Serialize 9 fractional digits.
    return result + '.%09ds' % nanos

  def FromJsonString(self, value):
    """Converts a string to Duration.

    Args:
      value: A string to be converted. The string must end with 's'. Any
        fractional digits (or none) are accepted as long as they fit into
        precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s"

    Raises:
      ParseError: On parsing problems.
    """
    if len(value) < 1 or value[-1] != 's':
      raise ParseError(
          'Duration must end with letter "s": {0}.'.format(value))
    try:
      pos = value.find('.')
      if pos == -1:
        self.seconds = int(value[:-1])
        self.nanos = 0
      else:
        self.seconds = int(value[:pos])
        # The fraction carries the sign of the whole value, so "-0.5s" gives
        # seconds == 0 and negative nanos.
        if value[0] == '-':
          self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
        else:
          self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
    except ValueError:
      raise ParseError(
          'Couldn\'t parse duration: {0}.'.format(value))

  def ToNanoseconds(self):
    """Converts a Duration to nanoseconds."""
    return self.seconds * _NANOS_PER_SECOND + self.nanos

  def ToMicroseconds(self):
    """Converts a Duration to microseconds."""
    micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND)
    return self.seconds * _MICROS_PER_SECOND + micros

  def ToMilliseconds(self):
    """Converts a Duration to milliseconds."""
    millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND)
    return self.seconds * _MILLIS_PER_SECOND + millis

  def ToSeconds(self):
    """Converts a Duration to seconds."""
    return self.seconds

  def FromNanoseconds(self, nanos):
    """Converts nanoseconds to Duration."""
    self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
                            nanos % _NANOS_PER_SECOND)

  def FromMicroseconds(self, micros):
    """Converts microseconds to Duration."""
    self._NormalizeDuration(
        micros // _MICROS_PER_SECOND,
        (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)

  def FromMilliseconds(self, millis):
    """Converts milliseconds to Duration."""
    self._NormalizeDuration(
        millis // _MILLIS_PER_SECOND,
        (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)

  def FromSeconds(self, seconds):
    """Converts seconds to Duration."""
    self.seconds = seconds
    self.nanos = 0

  def ToTimedelta(self):
    """Converts Duration to timedelta (nanos truncated toward zero)."""
    return timedelta(
        seconds=self.seconds, microseconds=_RoundTowardZero(
            self.nanos, _NANOS_PER_MICROSECOND))

  def FromTimedelta(self, td):
    """Converts timedelta to Duration."""
    self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
                            td.microseconds * _NANOS_PER_MICROSECOND)

  def _NormalizeDuration(self, seconds, nanos):
    """Set Duration by seconds and nanos.

    Floor division upstream yields non-negative nanos even for negative
    durations; this shifts one second into nanos so both fields share a sign.
    """
    # Force nanos to be negative if the duration is negative.
    if seconds < 0 and nanos > 0:
      seconds += 1
      nanos -= _NANOS_PER_SECOND
    self.seconds = seconds
    self.nanos = nanos


def _RoundTowardZero(value, divider):
  """Truncates the remainder part after division."""
  # For some languages, the sign of the remainder is implementation
  # dependent if any of the operands is negative. Here we enforce
  # "rounded toward zero" semantics. For example, for (-5) / 2 an
  # implementation may give -3 as the result with the remainder being
  # 1. This function ensures we always return -2 (closer to zero).
  result = value // divider
  remainder = value % divider
  if result < 0 and remainder > 0:
    return result + 1
  else:
    return result


class FieldMask(object):
  """Class for FieldMask message type."""

  def ToJsonString(self):
    """Converts FieldMask to string according to proto3 JSON spec."""
    return ','.join(self.paths)

  def FromJsonString(self, value):
    """Converts string to FieldMask according to proto3 JSON spec.

    An empty string yields an empty FieldMask (no paths), not one empty path.
    """
    self.Clear()
    if value:
      for path in value.split(','):
        self.paths.append(path)

  def IsValidForDescriptor(self, message_descriptor):
    """Checks whether the FieldMask is valid for Message Descriptor."""
    for path in self.paths:
      if not _IsValidPath(message_descriptor, path):
        return False
    return True

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Gets all direct fields of Message Descriptor to FieldMask."""
    self.Clear()
    for field in message_descriptor.fields:
      self.paths.append(field.name)

  def CanonicalFormFromMask(self, mask):
    """Converts a FieldMask to the canonical form.

    Removes paths that are covered by another path. For example,
    "foo.bar" is covered by "foo" and will be removed if "foo"
    is also in the FieldMask. Then sorts all paths in alphabetical order.

    Args:
      mask: The original FieldMask to be converted.
    """
    tree = _FieldMaskTree(mask)
    tree.ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Merges mask1 and mask2 into this FieldMask."""
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    tree.MergeFromFieldMask(mask2)
    tree.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Intersects mask1 and mask2 into this FieldMask."""
    _CheckFieldMaskMessage(mask1)
    _CheckFieldMaskMessage(mask2)
    tree = _FieldMaskTree(mask1)
    intersection = _FieldMaskTree()
    for path in mask2.paths:
      tree.IntersectPath(path, intersection)
    intersection.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Merges fields specified in FieldMask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
        field if False.
      replace_repeated_field: Replace repeated field if True. Append
        elements of repeated field if False.
    """
    tree = _FieldMaskTree(self)
    tree.MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)


def _IsValidPath(message_descriptor, path):
  """Checks whether the path is valid for Message Descriptor.

  Every component except the last must name a singular message field; the
  last must name any field of the innermost message. Unknown names make the
  path invalid (False) rather than raising.
  """
  parts = path.split('.')
  last = parts.pop()
  for name in parts:
    # .get(): plain indexing would raise KeyError for unknown names and make
    # the None check below dead code.
    field = message_descriptor.fields_by_name.get(name)
    if (field is None or
        field.label == FieldDescriptor.LABEL_REPEATED or
        field.type != FieldDescriptor.TYPE_MESSAGE):
      return False
    message_descriptor = field.message_type
  return last in message_descriptor.fields_by_name


def _CheckFieldMaskMessage(message):
  """Raises ValueError if message is not a FieldMask."""
  message_descriptor = message.DESCRIPTOR
  if (message_descriptor.name != 'FieldMask' or
      message_descriptor.file.name != 'google/protobuf/field_mask.proto'):
    raise ValueError('Message {0} is not a FieldMask.'.format(
        message_descriptor.full_name))


class _FieldMaskTree(object):
  """Represents a FieldMask in a tree structure.

  For example, given a FieldMask "foo.bar,foo.baz,bar.baz",
  the FieldMaskTree will be:
      [_root] -+- foo -+- bar
               |       |
               |       +- baz
               |
               +- bar --- baz
  In the tree, each leaf node represents a field path.
  """

  def __init__(self, field_mask=None):
    """Initializes the tree by FieldMask."""
    # Nested dicts: name -> child dict; an empty dict marks a leaf.
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Merges a FieldMask to the tree."""
    for path in field_mask.paths:
      self.AddPath(path)

  def AddPath(self, path):
    """Adds a field path into the tree.

    If the field path to add is a sub-path of an existing field path
    in the tree (i.e., a leaf node), it means the tree already matches
    the given path so nothing will be added to the tree. If the path
    matches an existing non-leaf node in the tree, that non-leaf node
    will be turned into a leaf node with all its children removed because
    the path matches all the node's children. Otherwise, a new path will
    be added.

    Args:
      path: The field path to add.
    """
    node = self._root
    for name in path.split('.'):
      if name not in node:
        node[name] = {}
      elif not node[name]:
        # Pre-existing empty node implies we already have this entire tree.
        return
      node = node[name]
    # Remove any sub-trees we might have had.
    node.clear()

  def ToFieldMask(self, field_mask):
    """Converts the tree to a FieldMask (paths are emitted sorted)."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Calculates the intersection part of a field path with this tree.

    Args:
      path: The field path to calculate.
      intersection: The out tree to record the intersection part.
    """
    node = self._root
    for name in path.split('.'):
      if name not in node:
        return
      elif not node[name]:
        # This tree already covers the whole of ``path``.
        intersection.AddPath(path)
        return
      node = node[name]
    # ``path`` ends at an inner node: every leaf below it is in both masks.
    intersection.AddLeafNodes(path, node)

  def AddLeafNodes(self, prefix, node):
    """Adds leaf nodes begin with prefix to this tree."""
    if not node:
      self.AddPath(prefix)
    for name in node:
      child_path = prefix + '.' + name
      self.AddLeafNodes(child_path, node[name])

  def MergeMessage(
      self, source, destination, replace_message, replace_repeated):
    """Merge all fields specified by this tree from source to destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)


def _StrConvert(value):
  """Converts value to str if it is not."""
  # This file is imported by c extension and some methods like ClearField
  # requires string for the field name. py2/py3 has different text
  # type and may use unicode.
  if not isinstance(value, str):
    return value.encode('utf-8')
  return value


def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merge all fields specified by a sub-tree from source to destination.

  Raises:
    ValueError: If a name in the tree is not a field of ``source``, or a
      non-leaf name is not a singular message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name in node:
    child = node[name]
    # .get(): plain indexing would raise KeyError and skip the descriptive
    # ValueError below.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      _MergeMessage(
          child, getattr(source, name), getattr(destination, name),
          replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        for item in repeated_source:
          repeated_destination.add().MergeFrom(item)
      else:
        repeated_destination.extend(repeated_source)
    else:
      if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
        if replace_message:
          destination.ClearField(_StrConvert(name))
        if source.HasField(name):
          getattr(destination, name).MergeFrom(getattr(source, name))
      else:
        setattr(destination, name, getattr(source, name))


def _AddFieldPaths(node, prefix, field_mask):
  """Adds the field paths descended from node to field_mask."""
  if not node:
    field_mask.paths.append(prefix)
    return
  for name in sorted(node):
    if prefix:
      child_path = prefix + '.' + name
    else:
      child_path = name
    _AddFieldPaths(node[name], child_path, field_mask)


_INT_OR_FLOAT = six.integer_types + (float,)


def _SetStructValue(struct_value, value):
  """Stores a Python scalar (None/bool/str/int/float) into a Value message.

  Raises:
    ValueError: If ``value`` has any other type.
  """
  if value is None:
    struct_value.null_value = 0
  elif isinstance(value, bool):
    # Note: this check must come before the number check because in Python
    # True and False are also considered numbers.
    struct_value.bool_value = value
  elif isinstance(value, six.string_types):
    struct_value.string_value = value
  elif isinstance(value, _INT_OR_FLOAT):
    struct_value.number_value = value
  else:
    raise ValueError('Unexpected type')


def _GetStructValue(struct_value):
  """Returns the Python value held by the 'kind' oneof of a Value message.

  Struct and list members are returned as the sub-messages themselves.

  Raises:
    ValueError: If no member of the oneof is set.
  """
  which = struct_value.WhichOneof('kind')
  if which == 'struct_value':
    return struct_value.struct_value
  elif which == 'null_value':
    return None
  elif which == 'number_value':
    return struct_value.number_value
  elif which == 'string_value':
    return struct_value.string_value
  elif which == 'bool_value':
    return struct_value.bool_value
  elif which == 'list_value':
    return struct_value.list_value
  elif which is None:
    raise ValueError('Value not set')


class Struct(object):
  """Class for Struct message type."""

  __slots__ = []

  def __getitem__(self, key):
    return _GetStructValue(self.fields[key])

  def __setitem__(self, key, value):
    _SetStructValue(self.fields[key], value)

  def get_or_create_list(self, key):
    """Returns a list for this key, creating if it didn't exist already."""
    value = self.fields[key]
    if not value.HasField('list_value'):
      # Merely reading ``list_value`` does not select it in the 'kind' oneof,
      # so ``self[key]`` would still raise 'Value not set'. Clear() marks the
      # empty sub-message as modified, which makes it present in the parent.
      value.list_value.Clear()
    return value.list_value

  def get_or_create_struct(self, key):
    """Returns a struct for this key, creating if it didn't exist already."""
    value = self.fields[key]
    if not value.HasField('struct_value'):
      # See get_or_create_list: Clear() makes the empty struct present.
      value.struct_value.Clear()
    return value.struct_value

  # TODO(haberman): allow constructing/merging from dict.
class ListValue(object):
  """Class for ListValue message type.

  Provides list-like access over the repeated ``values`` field, converting
  each element to and from its Python value.
  """

  def __len__(self):
    return len(self.values)

  def append(self, value):
    """Appends a Python scalar (None/bool/str/int/float) as a new element."""
    _SetStructValue(self.values.add(), value)

  def extend(self, elem_seq):
    """Appends every element of ``elem_seq`` in order."""
    for value in elem_seq:
      self.append(value)

  def __getitem__(self, index):
    """Retrieves item by the specified index."""
    return _GetStructValue(self.values[index])

  def __setitem__(self, index, value):
    _SetStructValue(self.values[index], value)

  def items(self):
    """Yields the Python value of each element in order."""
    for i in range(len(self)):
      yield self[i]

  def add_struct(self):
    """Appends and returns a struct value as the next value in the list."""
    struct_value = self.values.add().struct_value
    # Merely reading ``struct_value`` does not select it in the new Value's
    # 'kind' oneof, so ``self[i]`` would raise 'Value not set'. Clear() marks
    # the empty sub-message as modified, which makes it present.
    struct_value.Clear()
    return struct_value

  def add_list(self):
    """Appends and returns a list value as the next value in the list."""
    list_value = self.values.add().list_value
    # Same reason as add_struct: make the empty list present in the oneof.
    list_value.Clear()
    return list_value


# Maps a well known type's full proto name to the mixin class above that
# supplies its extra Python methods.
WKTBASES = {
    'google.protobuf.Any': Any,
    'google.protobuf.Duration': Duration,
    'google.protobuf.FieldMask': FieldMask,
    'google.protobuf.ListValue': ListValue,
    'google.protobuf.Struct': Struct,
    'google.protobuf.Timestamp': Timestamp,
}