aboutsummaryrefslogtreecommitdiff
path: root/java/core/src/main/java/com/google/protobuf/ExtensionRegistry.java
blob: aeeaee53e211291f679f0d02a9bfcea6cd20ac4e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package com.google.protobuf;

import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * A table of known extensions, searchable by name or field number. When parsing a protocol message
 * that might have extensions, you must provide an {@code ExtensionRegistry} in which you have
 * registered any extensions that you want to be able to parse. Otherwise, those extensions will
 * just be treated like unknown fields.
 *
 * <p>For example, if you had the {@code .proto} file:
 *
 * <pre>
 * option java_outer_classname = "MyProto";
 *
 * message Foo {
 *   extensions 1000 to max;
 * }
 *
 * extend Foo {
 *   optional int32 bar = 1000;
 * }
 * </pre>
 *
 * <p>Then you might write code like:
 *
 * <pre>
 * ExtensionRegistry registry = ExtensionRegistry.newInstance();
 * registry.add(MyProto.bar);
 * MyProto.Foo message = MyProto.Foo.parseFrom(input, registry);
 * </pre>
 *
 * <p>Background:
 *
 * <p>You might wonder why this is necessary. Two alternatives might come to mind. First, you might
 * imagine a system where generated extensions are automatically registered when their containing
 * classes are loaded. This is a popular technique, but is bad design; among other things, it
 * creates a situation where behavior can change depending on what classes happen to be loaded. It
 * also introduces a security vulnerability, because an unprivileged class could cause its code to
 * be called unexpectedly from a privileged class by registering itself as an extension of the right
 * type.
 *
 * <p>Another option you might consider is lazy parsing: do not parse an extension until it is first
 * requested, at which point the caller must provide a type to use. This introduces a different set
 * of problems. First, it would require a mutex lock any time an extension was accessed, which would
 * be slow. Second, corrupt data would not be detected until first access, at which point it would
 * be much harder to deal with it. Third, it could violate the expectation that message objects are
 * immutable, since the type provided could be any arbitrary message class. An unprivileged user
 * could take advantage of this to inject a mutable object into a message belonging to privileged
 * code and create mischief.
 *
 * <p>Registries created by {@link #newInstance()} store their extensions in unsynchronized {@link
 * HashMap}s, so they must not be modified while other threads may be reading from them.
 *
 * @author kenton@google.com Kenton Varda
 */
public class ExtensionRegistry extends ExtensionRegistryLite {
  /** Construct a new, empty instance. */
  public static ExtensionRegistry newInstance() {
    return new ExtensionRegistry();
  }

  /** Get the unmodifiable singleton empty instance. */
  public static ExtensionRegistry getEmptyRegistry() {
    return EMPTY_REGISTRY;
  }


  /**
   * Returns an unmodifiable view of the registry.
   *
   * <p>The lookup tables of the returned registry wrap this registry's tables with {@link
   * Collections#unmodifiableMap}, so it is a read-only view rather than a snapshot. Extensions
   * added to this registry afterwards are also visible through this class's lookup methods on the
   * returned registry.
   */
  @Override
  public ExtensionRegistry getUnmodifiable() {
    return new ExtensionRegistry(this);
  }

  /** A (Descriptor, Message) pair, returned by lookup methods. */
  public static final class ExtensionInfo {
    /** The extension's descriptor. */
    public final FieldDescriptor descriptor;

    /**
     * A default instance of the extension's type, if it has a message type. Otherwise, {@code
     * null}.
     */
    public final Message defaultInstance;

    private ExtensionInfo(final FieldDescriptor descriptor) {
      this.descriptor = descriptor;
      defaultInstance = null;
    }

    private ExtensionInfo(final FieldDescriptor descriptor, final Message defaultInstance) {
      this.descriptor = descriptor;
      this.defaultInstance = defaultInstance;
    }
  }

  /**
   * Finds an extension for immutable APIs by fully-qualified field name.
   *
   * @deprecated Use {@link #findImmutableExtensionByName(String)} instead.
   */
  @Deprecated
  public ExtensionInfo findExtensionByName(final String fullName) {
    return findImmutableExtensionByName(fullName);
  }

  /**
   * Find an extension for immutable APIs by fully-qualified field name, in the proto namespace.
   * i.e. {@code result.descriptor.fullName()} will match {@code fullName} if a match is found.
   *
   * @return Information about the extension if found, or {@code null} otherwise.
   */
  public ExtensionInfo findImmutableExtensionByName(final String fullName) {
    return immutableExtensionsByName.get(fullName);
  }

  /**
   * Find an extension for mutable APIs by fully-qualified field name, in the proto namespace. i.e.
   * {@code result.descriptor.fullName()} will match {@code fullName} if a match is found.
   *
   * @return Information about the extension if found, or {@code null} otherwise.
   */
  public ExtensionInfo findMutableExtensionByName(final String fullName) {
    return mutableExtensionsByName.get(fullName);
  }

  /**
   * Finds an extension for immutable APIs by containing type and field number.
   *
   * @deprecated Use {@link #findImmutableExtensionByNumber(Descriptors.Descriptor, int)} instead.
   */
  @Deprecated
  public ExtensionInfo findExtensionByNumber(
      final Descriptor containingType, final int fieldNumber) {
    return findImmutableExtensionByNumber(containingType, fieldNumber);
  }

  /**
   * Find an extension by containing type and field number for immutable APIs.
   *
   * @return Information about the extension if found, or {@code null} otherwise.
   */
  public ExtensionInfo findImmutableExtensionByNumber(
      final Descriptor containingType, final int fieldNumber) {
    return immutableExtensionsByNumber.get(new DescriptorIntPair(containingType, fieldNumber));
  }

  /**
   * Find an extension by containing type and field number for mutable APIs.
   *
   * @return Information about the extension if found, or {@code null} otherwise.
   */
  public ExtensionInfo findMutableExtensionByNumber(
      final Descriptor containingType, final int fieldNumber) {
    return mutableExtensionsByNumber.get(new DescriptorIntPair(containingType, fieldNumber));
  }

  /**
   * Find all extensions for mutable APIs by fully-qualified name of extended class. Note that this
   * method is more computationally expensive than getting a single extension by name or number,
   * since it scans every registered mutable extension.
   *
   * @return A new, modifiable set with information about the extensions found. It is empty (never
   *     {@code null}) if there are none.
   */
  public Set<ExtensionInfo> getAllMutableExtensionsByExtendedType(final String fullName) {
    return collectExtensionsByExtendedType(mutableExtensionsByNumber, fullName);
  }

  /**
   * Find all extensions for immutable APIs by fully-qualified name of extended class. Note that
   * this method is more computationally expensive than getting a single extension by name or
   * number, since it scans every registered immutable extension.
   *
   * @return A new, modifiable set with information about the extensions found. It is empty (never
   *     {@code null}) if there are none.
   */
  public Set<ExtensionInfo> getAllImmutableExtensionsByExtendedType(final String fullName) {
    return collectExtensionsByExtendedType(immutableExtensionsByNumber, fullName);
  }

  /**
   * Linear scan shared by the {@code getAll*ExtensionsByExtendedType} methods. It collects every
   * value of {@code extensionsByNumber} whose containing type's full name equals {@code fullName}.
   */
  private static Set<ExtensionInfo> collectExtensionsByExtendedType(
      final Map<DescriptorIntPair, ExtensionInfo> extensionsByNumber, final String fullName) {
    HashSet<ExtensionInfo> extensions = new HashSet<ExtensionInfo>();
    // Iterate entries directly instead of keySet() + get() to avoid a second lookup per match.
    for (Map.Entry<DescriptorIntPair, ExtensionInfo> entry : extensionsByNumber.entrySet()) {
      if (entry.getKey().descriptor.getFullName().equals(fullName)) {
        extensions.add(entry.getValue());
      }
    }
    return extensions;
  }

  /**
   * Add an extension from a generated file to the registry. Extensions whose type is neither
   * {@code IMMUTABLE} nor {@code MUTABLE} are silently ignored.
   *
   * @throws IllegalStateException if a message-typed extension has a {@code null} default instance
   * @throws IllegalArgumentException if the extension's descriptor is not an extension field
   */
  public void add(final Extension<?, ?> extension) {
    if (extension.getExtensionType() != Extension.ExtensionType.IMMUTABLE
        && extension.getExtensionType() != Extension.ExtensionType.MUTABLE) {
      // do not support other extension types. ignore
      return;
    }
    add(newExtensionInfo(extension), extension.getExtensionType());
  }

  /** Add an extension from a generated file to the registry. */
  public void add(final GeneratedMessage.GeneratedExtension<?, ?> extension) {
    add((Extension<?, ?>) extension);
  }

  /**
   * Builds the {@link ExtensionInfo} for {@code extension}, attaching its default instance when it
   * is message-typed.
   *
   * @throws IllegalStateException if a message-typed extension has a {@code null} default instance
   */
  static ExtensionInfo newExtensionInfo(final Extension<?, ?> extension) {
    if (extension.getDescriptor().getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
      if (extension.getMessageDefaultInstance() == null) {
        throw new IllegalStateException(
            "Registered message-type extension had null default instance: "
                + extension.getDescriptor().getFullName());
      }
      return new ExtensionInfo(
          extension.getDescriptor(), (Message) extension.getMessageDefaultInstance());
    } else {
      return new ExtensionInfo(extension.getDescriptor(), null);
    }
  }

  /**
   * Add a non-message-type extension to the registry by descriptor. The extension is registered
   * for both immutable and mutable lookups.
   *
   * @throws IllegalArgumentException if {@code type} is message-typed or is not an extension
   */
  public void add(final FieldDescriptor type) {
    if (type.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
      throw new IllegalArgumentException(
          "ExtensionRegistry.add() must be provided a default instance when "
              + "adding an embedded message extension.");
    }
    ExtensionInfo info = new ExtensionInfo(type, null);
    add(info, Extension.ExtensionType.IMMUTABLE);
    add(info, Extension.ExtensionType.MUTABLE);
  }

  /**
   * Add a message-type extension to the registry by descriptor.
   *
   * <p>Unlike {@link #add(FieldDescriptor)}, this registers the extension for immutable lookups
   * only; the {@code findMutable*} methods will not find it. Presumably this is because {@code
   * defaultInstance} is an immutable {@link Message}.
   *
   * @throws IllegalArgumentException if {@code type} is not message-typed or is not an extension
   */
  public void add(final FieldDescriptor type, final Message defaultInstance) {
    if (type.getJavaType() != FieldDescriptor.JavaType.MESSAGE) {
      throw new IllegalArgumentException(
          "ExtensionRegistry.add() provided a default instance for a non-message extension.");
    }
    add(new ExtensionInfo(type, defaultInstance), Extension.ExtensionType.IMMUTABLE);
  }

  // =================================================================
  // Private stuff.

  private ExtensionRegistry() {
    this.immutableExtensionsByName = new HashMap<String, ExtensionInfo>();
    this.mutableExtensionsByName = new HashMap<String, ExtensionInfo>();
    this.immutableExtensionsByNumber = new HashMap<DescriptorIntPair, ExtensionInfo>();
    this.mutableExtensionsByNumber = new HashMap<DescriptorIntPair, ExtensionInfo>();
  }

  /** Creates a read-only view over {@code other}'s tables (see {@link #getUnmodifiable()}). */
  private ExtensionRegistry(ExtensionRegistry other) {
    super(other);
    this.immutableExtensionsByName = Collections.unmodifiableMap(other.immutableExtensionsByName);
    this.mutableExtensionsByName = Collections.unmodifiableMap(other.mutableExtensionsByName);
    this.immutableExtensionsByNumber =
        Collections.unmodifiableMap(other.immutableExtensionsByNumber);
    this.mutableExtensionsByNumber = Collections.unmodifiableMap(other.mutableExtensionsByNumber);
  }

  private final Map<String, ExtensionInfo> immutableExtensionsByName;
  private final Map<String, ExtensionInfo> mutableExtensionsByName;
  private final Map<DescriptorIntPair, ExtensionInfo> immutableExtensionsByNumber;
  private final Map<DescriptorIntPair, ExtensionInfo> mutableExtensionsByNumber;

  /** Builds the shared empty registry; {@code empty} only distinguishes this overload. */
  ExtensionRegistry(boolean empty) {
    super(EMPTY_REGISTRY_LITE);
    this.immutableExtensionsByName = Collections.<String, ExtensionInfo>emptyMap();
    this.mutableExtensionsByName = Collections.<String, ExtensionInfo>emptyMap();
    this.immutableExtensionsByNumber = Collections.<DescriptorIntPair, ExtensionInfo>emptyMap();
    this.mutableExtensionsByNumber = Collections.<DescriptorIntPair, ExtensionInfo>emptyMap();
  }

  static final ExtensionRegistry EMPTY_REGISTRY = new ExtensionRegistry(true);

  /**
   * Registers {@code extension} in the name and number tables selected by {@code extensionType}.
   * Types other than {@code IMMUTABLE} and {@code MUTABLE} are ignored.
   *
   * @throws IllegalArgumentException if the descriptor is not an extension field
   */
  private void add(final ExtensionInfo extension, final Extension.ExtensionType extensionType) {
    if (!extension.descriptor.isExtension()) {
      throw new IllegalArgumentException(
          "ExtensionRegistry.add() was given a FieldDescriptor for a regular "
              + "(non-extension) field.");
    }

    Map<String, ExtensionInfo> extensionsByName;
    Map<DescriptorIntPair, ExtensionInfo> extensionsByNumber;
    switch (extensionType) {
      case IMMUTABLE:
        extensionsByName = immutableExtensionsByName;
        extensionsByNumber = immutableExtensionsByNumber;
        break;
      case MUTABLE:
        extensionsByName = mutableExtensionsByName;
        extensionsByNumber = mutableExtensionsByNumber;
        break;
      default:
        // Ignore unsupported extension types.
        return;
    }

    extensionsByName.put(extension.descriptor.getFullName(), extension);
    extensionsByNumber.put(
        new DescriptorIntPair(
            extension.descriptor.getContainingType(), extension.descriptor.getNumber()),
        extension);

    final FieldDescriptor field = extension.descriptor;
    if (field.getContainingType().getOptions().getMessageSetWireFormat()
        && field.getType() == FieldDescriptor.Type.MESSAGE
        && field.isOptional()
        && field.getExtensionScope() == field.getMessageType()) {
      // This is an extension of a MessageSet type defined within the extension
      // type's own scope.  For backwards-compatibility, allow it to be looked
      // up by type name.
      extensionsByName.put(field.getMessageType().getFullName(), extension);
    }
  }

  /**
   * A (Descriptor, int) pair, used as a map key. Descriptors are compared by identity ({@code ==})
   * in {@link #equals}.
   */
  private static final class DescriptorIntPair {
    private final Descriptor descriptor;
    private final int number;

    DescriptorIntPair(final Descriptor descriptor, final int number) {
      this.descriptor = descriptor;
      this.number = number;
    }

    @Override
    public int hashCode() {
      return descriptor.hashCode() * ((1 << 16) - 1) + number;
    }

    @Override
    public boolean equals(final Object obj) {
      if (!(obj instanceof DescriptorIntPair)) {
        return false;
      }
      final DescriptorIntPair other = (DescriptorIntPair) obj;
      return descriptor == other.descriptor && number == other.number;
    }
  }
}