diff options
Diffstat (limited to 'python/google/protobuf/reflection.py')
-rwxr-xr-x | python/google/protobuf/reflection.py | 295 |
1 files changed, 94 insertions, 201 deletions
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 75202c4e..ef054466 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -43,6 +43,7 @@ import weakref from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod @@ -261,8 +262,8 @@ def _DefaultValueForField(message, field): # been set. (Depends on order in which we initialize the classes). return _RepeatedCompositeFieldContainer(listener, field.message_type) else: - return _RepeatedScalarFieldContainer(listener, - _VALUE_CHECKERS[field.cpp_type]) + return _RepeatedScalarFieldContainer( + listener, type_checkers.VALUE_CHECKERS[field.cpp_type]) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: assert field.default_value is None @@ -370,7 +371,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): python_field_name = _ValueFieldName(proto_field_name) has_field_name = _HasFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) - type_checker = _VALUE_CHECKERS[field.cpp_type] + type_checker = type_checkers.VALUE_CHECKERS[field.cpp_type] def getter(self): return getattr(self, python_field_name) @@ -614,7 +615,7 @@ def _BytesForNonRepeatedElement(value, field_number, field_type): within FieldDescriptor. """ try: - fn = _TYPE_TO_BYTE_SIZE_FN[field_type] + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] return fn(field_number, value) except KeyError: raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) @@ -707,7 +708,7 @@ def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): return try: - method = _TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] + method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] method(encoder, field_number, value) except KeyError: raise message_mod.EncodeError('Unrecognized field type: %d' % @@ -748,15 +749,24 @@ def _ImergeSorted(*streams): def _AddSerializeToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - Encoder = encoder.Encoder def SerializeToString(self): + # Check if the message has all of its required fields set. + errors = [] + if not _InternalIsInitialized(self, errors): + raise message_mod.EncodeError('\n'.join(errors)) + return self.SerializePartialToString() + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Encoder = encoder.Encoder + + def SerializePartialToString(self): encoder = Encoder() # We need to serialize all extension and non-extension fields # together, in sorted order by field number. - - # Step 3: Iterate over all extension and non-extension fields, sorted - # in order of tag number, and serialize each one to the wire. for field_descriptor, field_value in self.ListFields(): if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: repeated_value = field_value @@ -766,13 +776,13 @@ def _AddSerializeToStringMethod(message_descriptor, cls): _SerializeValueToEncoder(element, field_descriptor.number, field_descriptor, encoder) return encoder.ToString() - cls.SerializeToString = SerializeToString + cls.SerializePartialToString = SerializePartialToString def _WireTypeForFieldType(field_type): """Given a field type, returns the expected wire type.""" try: - return _FIELD_TYPE_TO_WIRE_TYPE[field_type] + return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type] except KeyError: raise message_mod.DecodeError('Unknown field type: %d' % field_type) @@ -804,7 +814,7 @@ def _DeserializeScalarFromDecoder(field_type, decoder): be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant. """ try: - method = _TYPE_TO_DESERIALIZE_METHOD[field_type] + method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type] return method(decoder) except KeyError: raise message_mod.DecodeError('Unrecognized field type: %d' % field_type) @@ -1034,12 +1044,13 @@ def _HasFieldOrExtension(message, field_or_extension): return message.HasField(field_or_extension.name) -def _IsFieldOrExtensionInitialized(message, field): +def _IsFieldOrExtensionInitialized(message, field, errors=None): """Checks if a message field or extension is initialized. Args: message: The message which contains the field or extension. field: Field or extension to check. This must be a FieldDescriptor instance. + errors: Errors will be appended to it, if set to a meaningful value. Returns: True if the field/extension can be considered initialized. @@ -1047,6 +1058,8 @@ def _IsFieldOrExtensionInitialized(message, field): # If the field is required and is not set, it isn't initialized. if field.label == _FieldDescriptor.LABEL_REQUIRED: if not _HasFieldOrExtension(message, field): + if errors is not None: + errors.append('Required field %s is not set.' % field.full_name) return False # If the field is optional and is not set, or if it @@ -1062,7 +1075,27 @@ def _IsFieldOrExtensionInitialized(message, field): # If all submessages in this field are initialized, the field is # considered initialized. for message in messages: - if not message.IsInitialized(): + if not _InternalIsInitialized(message, errors): + return False + return True + + +def _InternalIsInitialized(message, errors=None): + """Checks if all required fields of a message are set. + + Args: + message: The message to check. + errors: If set, initialization errors will be appended to it. + + Returns: + True iff the specified message has all required fields set. + """ + fields_and_extensions = [] + fields_and_extensions.extend(message.DESCRIPTOR.fields) + fields_and_extensions.extend( + [extension[0] for extension in message.Extensions._ListSetExtensions()]) + for field_or_extension in fields_and_extensions: + if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors): return False return True @@ -1082,25 +1115,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls): cls.MergeFromString = MergeFromString -def _AddIsInitializedMethod(message_descriptor, cls): +def _AddIsInitializedMethod(cls): """Adds the IsInitialized method to the protocol message class.""" - def IsInitialized(self): - fields_and_extensions = [] - fields_and_extensions.extend(message_descriptor.fields) - fields_and_extensions.extend( - self.Extensions._AllExtensionsByNumber().values()) - for field_or_extension in fields_and_extensions: - if not _IsFieldOrExtensionInitialized(self, field_or_extension): - return False - return True - cls.IsInitialized = IsInitialized + cls.IsInitialized = _InternalIsInitialized -def _AddMessageMethods(message_descriptor, cls): - """Adds implementations of all Message methods to cls.""" +def _MergeFieldOrExtension(destination_msg, field, value): + """Merges a specified message field into another message.""" + property_name = _PropertyName(field.name) + is_extension = field.is_extension - # TODO(robinson): Add support for remaining Message methods. + if not is_extension: + destination = getattr(destination_msg, property_name) + elif (field.label == _FieldDescriptor.LABEL_REPEATED or + field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + destination = destination_msg.Extensions[field] + # Case 1 - a composite field. + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for v in value: + destination.add().MergeFrom(v) + else: + destination.MergeFrom(value) + return + + # Case 2 - a repeated field. + if field.label == _FieldDescriptor.LABEL_REPEATED: + for v in value: + destination.append(v) + return + + # Case 3 - a singular field. + if is_extension: + destination_msg.Extensions[field] = value + else: + setattr(destination_msg, property_name, value) + + +def _AddMergeFromMethod(cls): + def MergeFrom(self, msg): + assert msg is not self + for field in msg.ListFields(): + _MergeFieldOrExtension(self, field[0], field[1]) + cls.MergeFrom = MergeFrom + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) _AddHasFieldMethod(cls) _AddClearFieldMethod(cls) @@ -1111,8 +1173,10 @@ def _AddMessageMethods(message_descriptor, cls): _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) _AddMergeFromStringMethod(message_descriptor, cls) - _AddIsInitializedMethod(message_descriptor, cls) + _AddIsInitializedMethod(cls) + _AddMergeFromMethod(cls) def _AddPrivateHelperMethods(cls): @@ -1440,7 +1504,7 @@ class _ExtensionDict(object): and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. - type_checker = _VALUE_CHECKERS[field.cpp_type] + type_checker = type_checkers.VALUE_CHECKERS[field.cpp_type] type_checker.CheckValue(value) self._values[handle_id] = value self._has_bits[handle_id] = True @@ -1561,174 +1625,3 @@ class _ExtensionDict(object): # be careful when we move away from having _known_extensions as a # deep-copied member of this object. return dict((f.number, f) for f in self._known_extensions.itervalues()) - - -# None of the typecheckers below make any attempt to guard against people -# subclassing builtin types and doing weird things. We're not trying to -# protect against malicious clients here, just people accidentally shooting -# themselves in the foot in obvious ways. - -class _TypeChecker(object): - - """Type checker used to catch type errors as early as possible - when the client is setting scalar fields in protocol messages. - """ - - def __init__(self, *acceptable_types): - self._acceptable_types = acceptable_types - - def CheckValue(self, proposed_value): - if not isinstance(proposed_value, self._acceptable_types): - message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), self._acceptable_types)) - raise TypeError(message) - - -# _IntValueChecker and its subclasses perform integer type-checks -# and bounds-checks. -class _IntValueChecker(object): - - """Checker used for integer fields. Performs type-check and range check.""" - - def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (int, long)): - message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (int, long))) - raise TypeError(message) - if not self._MIN <= proposed_value <= self._MAX: - raise ValueError('Value out of range: %d' % proposed_value) - -class _Int32ValueChecker(_IntValueChecker): - # We're sure to use ints instead of longs here since comparison may be more - # efficient. - _MIN = -2147483648 - _MAX = 2147483647 - -class _Uint32ValueChecker(_IntValueChecker): - _MIN = 0 - _MAX = (1 << 32) - 1 - -class _Int64ValueChecker(_IntValueChecker): - _MIN = -(1 << 63) - _MAX = (1 << 63) - 1 - -class _Uint64ValueChecker(_IntValueChecker): - _MIN = 0 - _MAX = (1 << 64) - 1 - - -# Type-checkers for all scalar CPPTYPEs. -_VALUE_CHECKERS = { - _FieldDescriptor.CPPTYPE_INT32: _Int32ValueChecker(), - _FieldDescriptor.CPPTYPE_INT64: _Int64ValueChecker(), - _FieldDescriptor.CPPTYPE_UINT32: _Uint32ValueChecker(), - _FieldDescriptor.CPPTYPE_UINT64: _Uint64ValueChecker(), - _FieldDescriptor.CPPTYPE_DOUBLE: _TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_FLOAT: _TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_BOOL: _TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_ENUM: _Int32ValueChecker(), - _FieldDescriptor.CPPTYPE_STRING: _TypeChecker(str), - } - - -# Map from field type to a function F, such that F(field_num, value) -# gives the total byte size for a value of the given type. This -# byte size includes tag information and any other additional space -# associated with serializing "value". -_TYPE_TO_BYTE_SIZE_FN = { - _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize, - _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize, - _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize, - _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize, - _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize, - _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize, - _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize, - _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize, - _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize, - _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize, - _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize, - _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize, - _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize, - _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize, - _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize, - _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize, - _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize, - _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize - } - -# Maps from field type to an unbound Encoder method F, such that -# F(encoder, field_number, value) will append the serialization -# of a value of this type to the encoder. -_Encoder = encoder.Encoder -_TYPE_TO_SERIALIZE_METHOD = { - _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble, - _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat, - _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64, - _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64, - _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32, - _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64, - _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32, - _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool, - _FieldDescriptor.TYPE_STRING: _Encoder.AppendString, - _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup, - _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage, - _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes, - _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32, - _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum, - _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32, - _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64, - _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32, - _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64, - } - -# Maps from field type to expected wiretype. -_FIELD_TYPE_TO_WIRE_TYPE = { - _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, - _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32, - _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64, - _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32, - _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_STRING: - wire_format.WIRETYPE_LENGTH_DELIMITED, - _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP, - _FieldDescriptor.TYPE_MESSAGE: - wire_format.WIRETYPE_LENGTH_DELIMITED, - _FieldDescriptor.TYPE_BYTES: - wire_format.WIRETYPE_LENGTH_DELIMITED, - _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32, - _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64, - _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, - _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, - } - -# Maps from field type to an unbound Decoder method F, -# such that F(decoder) will read a field of the requested type. -# -# Note that Message and Group are intentionally missing here. -# They're handled by _RecursivelyMerge(). -_Decoder = decoder.Decoder -_TYPE_TO_DESERIALIZE_METHOD = { - _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble, - _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat, - _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64, - _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64, - _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32, - _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64, - _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32, - _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool, - _FieldDescriptor.TYPE_STRING: _Decoder.ReadString, - _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes, - _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32, - _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum, - _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32, - _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64, - _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32, - _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64, - } |