aboutsummaryrefslogtreecommitdiff
path: root/src/ProtoGen/UmbrellaClassGenerator.cs
blob: 732b6f8da711e3dad1e92926040bebbdfa02ef98 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
using System;
using System.Collections;
using System.Collections.Generic;
using Google.ProtocolBuffers.DescriptorProtos;
using Google.ProtocolBuffers.Descriptors;

namespace Google.ProtocolBuffers.ProtoGen {
  /// <summary>
  /// Generator for the "umbrella" class describing the .proto file in general,
  /// containing things like the file descriptor, extension registration and the
  /// static variables of each message type.
  /// </summary>
  internal sealed class UmbrellaClassGenerator : SourceGeneratorBase<FileDescriptor>, ISourceGenerator {

    /// <summary>
    /// Maximum number of Base64 characters written per string-literal line when
    /// embedding the serialized file descriptor in the generated code.
    /// </summary>
    private const int Base64LineLength = 60;

    internal UmbrellaClassGenerator(FileDescriptor descriptor)
      : base(descriptor) {
    }

    /// <summary>
    /// Recursively searches the given message to see if it contains any extensions.
    /// </summary>
    /// <remarks>
    /// Unknown fields are conservatively assumed to be extensions.
    /// </remarks>
    /// <param name="message">The message to inspect.</param>
    /// <returns>
    /// True if the message, or any message nested within it (singular or repeated),
    /// has an extension field or any unknown fields; false otherwise.
    /// </returns>
    private static bool UsesExtensions(IMessage message) {
      // We conservatively assume that unknown fields are extensions.
      if (message.UnknownFields.FieldDictionary.Count > 0) {
        return true;
      }

      foreach (KeyValuePair<FieldDescriptor, object> keyValue in message.AllFields) {
        FieldDescriptor field = keyValue.Key;
        if (field.IsExtension) {
          return true;
        }
        if (field.MappedType == MappedType.Message) {
          if (field.IsRepeated) {
            foreach (IMessage subMessage in (IEnumerable)keyValue.Value) {
              if (UsesExtensions(subMessage)) {
                return true;
              }
            }
          } else {
            if (UsesExtensions((IMessage)keyValue.Value)) {
              return true;
            }
          }
        }
      }
      return false;
    }

    /// <summary>
    /// The name of the generated umbrella class, taken from the file's C# options.
    /// </summary>
    public string UmbrellaClassName {
      get { return Descriptor.CSharpOptions.UmbrellaClassname; }
    }

    /// <summary>
    /// True if the generated code is wrapped in a namespace declaration.
    /// Used both when opening and when closing the namespace so the two always agree;
    /// a null namespace is treated the same as an empty one.
    /// </summary>
    private bool HasNamespace {
      get { return !string.IsNullOrEmpty(Descriptor.CSharpOptions.Namespace); }
    }

    /// <summary>
    /// Writes the complete source for the umbrella class: header and namespace,
    /// extension registration, extensions, message static variables, the file
    /// descriptor, and the enums, messages and services of the file.
    /// </summary>
    /// <param name="writer">The writer receiving the generated code.</param>
    public void Generate(TextGenerator writer) {
      WriteIntroduction(writer);
      WriteExtensionRegistration(writer);
      WriteChildren(writer, "Extensions", Descriptor.Extensions);
      writer.WriteLine("#region Static variables");
      foreach (MessageDescriptor message in Descriptor.MessageTypes) {
        new MessageGenerator(message).GenerateStaticVariables(writer);
      }
      writer.WriteLine("#endregion");
      WriteDescriptor(writer);

      // The umbrella class declaration is closed either before the enums, messages
      // and services are written (top-level types) or after them (nested types).
      bool nestClasses = Descriptor.CSharpOptions.NestClasses;
      if (!nestClasses) {
        CloseBlock(writer);
      }
      WriteChildren(writer, "Enums", Descriptor.EnumTypes);
      WriteChildren(writer, "Messages", Descriptor.MessageTypes);
      WriteChildren(writer, "Services", Descriptor.Services);
      if (nestClasses) {
        CloseBlock(writer);
      }

      // Close the namespace opened by WriteIntroduction, if one was opened.
      if (HasNamespace) {
        CloseBlock(writer);
      }
    }

    /// <summary>
    /// Outdents the writer and writes a closing brace.
    /// </summary>
    private static void CloseBlock(TextGenerator writer) {
      writer.Outdent();
      writer.WriteLine("}");
    }

    /// <summary>
    /// Writes the "generated" header comment, the namespace using directives,
    /// the optional namespace declaration and the opening of the umbrella class.
    /// Leaves the writer indented inside the class body.
    /// </summary>
    private void WriteIntroduction(TextGenerator writer) {
      writer.WriteLine("// Generated by the protocol buffer compiler.  DO NOT EDIT!");
      writer.WriteLine();
      Helpers.WriteNamespaces(writer);

      if (HasNamespace) {
        writer.WriteLine("namespace {0} {{", Descriptor.CSharpOptions.Namespace);
        writer.Indent();
        writer.WriteLine();
      }

      writer.WriteLine("{0} static partial class {1} {{", ClassAccessLevel, UmbrellaClassName);
      writer.WriteLine();
      writer.Indent();
    }

    /// <summary>
    /// Writes the generated RegisterAllExtensions method, which registers this
    /// file's top-level extensions followed by those of each message type.
    /// </summary>
    private void WriteExtensionRegistration(TextGenerator writer) {
      writer.WriteLine("#region Extension registration");
      writer.WriteLine("public static void RegisterAllExtensions(pb::ExtensionRegistry registry) {");
      writer.Indent();
      foreach (FieldDescriptor extension in Descriptor.Extensions) {
        new ExtensionGenerator(extension).GenerateExtensionRegistrationCode(writer);
      }
      foreach (MessageDescriptor message in Descriptor.MessageTypes) {
        new MessageGenerator(message).GenerateExtensionRegistrationCode(writer);
      }
      writer.Outdent();
      writer.WriteLine("}");
      writer.WriteLine("#endregion");
    }

    /// <summary>
    /// Writes the static Descriptor property and the static constructor that
    /// rebuilds the file descriptor from its embedded serialized form.
    /// </summary>
    private void WriteDescriptor(TextGenerator writer) {
      writer.WriteLine("#region Descriptor");

      writer.WriteLine("public static pbd::FileDescriptor Descriptor {");
      writer.WriteLine("  get { return descriptor; }");
      writer.WriteLine("}");
      writer.WriteLine("private static pbd::FileDescriptor descriptor;");
      writer.WriteLine();
      writer.WriteLine("static {0}() {{", UmbrellaClassName);
      writer.Indent();
      WriteEncodedDescriptorData(writer);
      WriteDescriptorAssigner(writer);

      // -----------------------------------------------------------------
      // Invoke InternalBuildGeneratedFileFrom() to build the file, passing the
      // descriptors of every direct dependency.
      writer.WriteLine("pbd::FileDescriptor.InternalBuildGeneratedFileFrom(descriptorData,");
      writer.WriteLine("    new pbd::FileDescriptor[] {");
      foreach (FileDescriptor dependency in Descriptor.Dependencies) {
        writer.WriteLine("    {0}.Descriptor, ", DescriptorUtil.GetFullUmbrellaClassName(dependency));
      }
      writer.WriteLine("    }, assigner);");
      writer.Outdent();
      writer.WriteLine("}");
      writer.WriteLine("#endregion");
      writer.WriteLine();
    }

    /// <summary>
    /// Writes the "descriptorData" local: the serialized FileDescriptorProto,
    /// Base64-encoded and split into concatenated string literals.
    /// </summary>
    private void WriteEncodedDescriptorData(TextGenerator writer) {
      writer.WriteLine("byte[] descriptorData = global::System.Convert.FromBase64String(");
      writer.Indent();
      writer.Indent();

      // TODO(jonskeet): Consider a C#-escaping format here instead of just Base64.
      byte[] bytes = Descriptor.Proto.ToByteArray();
      string base64 = Convert.ToBase64String(bytes);

      // Walk the string by offset rather than repeatedly re-slicing the remainder,
      // which copied the rest of the string on every line.
      int offset = 0;
      while (base64.Length - offset > Base64LineLength) {
        writer.WriteLine("\"{0}\" + ", base64.Substring(offset, Base64LineLength));
        offset += Base64LineLength;
      }
      writer.WriteLine("\"{0}\");", base64.Substring(offset));
      writer.Outdent();
      writer.Outdent();
    }

    /// <summary>
    /// Writes the InternalDescriptorAssigner delegate. It stores the root
    /// descriptor, initializes message and extension static variables, and
    /// returns an ExtensionRegistry when the descriptor itself uses extensions
    /// (or null otherwise).
    /// </summary>
    private void WriteDescriptorAssigner(TextGenerator writer) {
      writer.WriteLine("pbd::FileDescriptor.InternalDescriptorAssigner assigner = delegate(pbd::FileDescriptor root) {");
      writer.Indent();
      writer.WriteLine("descriptor = root;");
      foreach (MessageDescriptor message in Descriptor.MessageTypes) {
        new MessageGenerator(message).GenerateStaticVariableInitializers(writer);
      }
      foreach (FieldDescriptor extension in Descriptor.Extensions) {
        new ExtensionGenerator(extension).GenerateStaticVariableInitializers(writer);
      }

      if (UsesExtensions(Descriptor.Proto)) {
        // Must construct an ExtensionRegistry containing this file's extensions
        // and those of each direct dependency, and return it.
        writer.WriteLine("pb::ExtensionRegistry registry = pb::ExtensionRegistry.CreateInstance();");
        writer.WriteLine("RegisterAllExtensions(registry);");
        foreach (FileDescriptor dependency in Descriptor.Dependencies) {
          writer.WriteLine("{0}.RegisterAllExtensions(registry);", DescriptorUtil.GetFullUmbrellaClassName(dependency));
        }
        writer.WriteLine("return registry;");
      } else {
        writer.WriteLine("return null;");
      }
      writer.Outdent();
      writer.WriteLine("};");
    }
  }
}