diff options
Diffstat (limited to 'compiler/src/dotty/tools/dotc/ast/Desugar.scala')
-rw-r--r-- | compiler/src/dotty/tools/dotc/ast/Desugar.scala | 90 |
1 files changed, 55 insertions, 35 deletions
diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 87994a87b..863c10cb0 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -14,6 +14,7 @@ import reporting.diagnostic.messages._ object desugar { import untpd._ + import DesugarEnums._ /** Tags a .withFilter call generated by desugaring a for expression. * Such calls can alternatively be rewritten to use filter. @@ -263,7 +264,9 @@ object desugar { val className = checkNotReservedName(cdef).asTypeName val impl @ Template(constr0, parents, self, _) = cdef.rhs val mods = cdef.mods - val companionMods = mods.withFlags((mods.flags & AccessFlags).toCommonFlags) + val companionMods = mods + .withFlags((mods.flags & AccessFlags).toCommonFlags) + .withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase])) val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match { case meth: DefDef => (meth, Nil) @@ -288,17 +291,22 @@ object desugar { } val isCaseClass = mods.is(Case) && !mods.is(Module) + val isEnum = mods.hasMod[Mod.Enum] + val isEnumCase = isLegalEnumCase(cdef) val isValueClass = parents.nonEmpty && isAnyVal(parents.head) // This is not watertight, but `extends AnyVal` will be replaced by `inline` later. - val constrTparams = constr1.tparams map toDefParam + val originalTparams = + if (isEnumCase && parents.isEmpty) reconstitutedEnumTypeParams(cdef.pos.startPos) + else constr1.tparams + val originalVparamss = constr1.vparamss + val constrTparams = originalTparams.map(toDefParam) val constrVparamss = - if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty - if (isCaseClass) - ctx.error(CaseClassMissingParamList(cdef), cdef.namePos) + if (originalVparamss.isEmpty) { // ensure parameter list is non-empty + if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos) ListOfNil } - else constr1.vparamss.nestedMap(toDefParam) + else originalVparamss.nestedMap(toDefParam) val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss) // Add constructor type parameters and evidence implicit parameters @@ -312,21 +320,22 @@ object desugar { stat } - val derivedTparams = constrTparams map derivedTypeParam + val derivedTparams = + if (isEnumCase) constrTparams else constrTparams map derivedTypeParam val derivedVparamss = constrVparamss nestedMap derivedTermParam val arity = constrVparamss.head.length - var classTycon: Tree = EmptyTree + val classTycon: Tree = new TypeRefTree // watching is set at end of method - // a reference to the class type, with all parameters given. - val classTypeRef/*: Tree*/ = { - // -language:keepUnions difference: classTypeRef needs type annotation, otherwise - // infers Ident | AppliedTypeTree, which - // renders the :\ in companions below untypable. - classTycon = (new TypeRefTree) withPos cdef.pos.startPos // watching is set at end of method - val tparams = impl.constr.tparams - if (tparams.isEmpty) classTycon else AppliedTypeTree(classTycon, tparams map refOfDef) - } + def appliedRef(tycon: Tree) = + (if (constrTparams.isEmpty) tycon + else AppliedTypeTree(tycon, constrTparams map refOfDef)) + .withPos(cdef.pos.startPos) + + // a reference to the class type bound by `cdef`, with type parameters coming from the constructor + val classTypeRef = appliedRef(classTycon) + // a refereence to `enumClass`, with type parameters coming from the constructor + lazy val enumClassTypeRef = appliedRef(enumClassRef) // new C[Ts](paramss) lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef) @@ -374,7 +383,9 @@ object desugar { DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr) .withMods(synthetic) :: Nil } - copyMeths ::: productElemMeths.toList + + val enumTagMeths = if (isEnumCase) enumTagMeth :: Nil else Nil + copyMeths ::: enumTagMeths ::: productElemMeths.toList } else Nil @@ -387,8 +398,12 @@ object desugar { // Case classes and case objects get a ProductN parent var parents1 = parents + if (isEnumCase && parents.isEmpty) + parents1 = enumClassTypeRef :: Nil if (mods.is(Case) && arity <= Definitions.MaxTupleArity) - parents1 = parents1 :+ productConstr(arity) + parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that? + if (isEnum) + parents1 = parents1 :+ ref(defn.EnumType) // The thicket which is the desugared version of the companion object // synthetic object C extends parentTpt { defs } @@ -419,9 +434,11 @@ object desugar { else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function(vparams map (_.tpt), restpe)) val applyMeths = if (mods is Abstract) Nil - else - DefDef(nme.apply, derivedTparams, derivedVparamss, TypeTree(), creatorExpr) + else { + val restpe = if (isEnumCase) enumClassTypeRef else TypeTree() + DefDef(nme.apply, derivedTparams, derivedVparamss, restpe, creatorExpr) .withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil + } val unapplyMeth = { val unapplyParam = makeSyntheticParameter(tpt = classTypeRef) val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name) @@ -464,15 +481,15 @@ object desugar { else cpy.ValDef(self)(tpt = selfType).withMods(self.mods | SelfName) } - val cdef1 = { - val originalTparams = constr1.tparams.toIterator - val originalVparams = constr1.vparamss.toIterator.flatten - val tparamAccessors = derivedTparams.map(_.withMods(originalTparams.next.mods)) + val cdef1 = addEnumFlags { + val originalTparamsIt = originalTparams.toIterator + val originalVparamsIt = originalVparamss.toIterator.flatten + val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next.mods)) val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags val vparamAccessors = derivedVparamss match { case first :: rest => - first.map(_.withMods(originalVparams.next.mods | caseAccessor)) ++ - rest.flatten.map(_.withMods(originalVparams.next.mods)) + first.map(_.withMods(originalVparamsIt.next.mods | caseAccessor)) ++ + rest.flatten.map(_.withMods(originalVparamsIt.next.mods)) case _ => Nil } @@ -503,23 +520,26 @@ object desugar { */ def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = { val moduleName = checkNotReservedName(mdef).asTermName - val tmpl = mdef.impl + val impl = mdef.impl val mods = mdef.mods + lazy val isEnumCase = isLegalEnumCase(mdef) if (mods is Package) - PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, tmpl).withMods(mods &~ Package) :: Nil) + PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil) + else if (isEnumCase) + expandEnumModule(moduleName, impl, mods, mdef.pos) else { val clsName = moduleName.moduleClassName val clsRef = Ident(clsName) val modul = ValDef(moduleName, clsRef, New(clsRef, Nil)) .withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags) .withPos(mdef.pos) - val ValDef(selfName, selfTpt, _) = tmpl.self - val selfMods = tmpl.self.mods - if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), tmpl.self.pos) - val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), tmpl.self.rhs) + val ValDef(selfName, selfTpt, _) = impl.self + val selfMods = impl.self.mods + if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), impl.self.pos) + val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), impl.self.rhs) .withMods(selfMods) - .withPos(tmpl.self.pos orElse tmpl.pos.startPos) - val clsTmpl = cpy.Template(tmpl)(self = clsSelf, body = tmpl.body) + .withPos(impl.self.pos orElse impl.pos.startPos) + val clsTmpl = cpy.Template(impl)(self = clsSelf, body = impl.body) val cls = TypeDef(clsName, clsTmpl) .withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags) Thicket(modul, classDef(cls).withPos(mdef.pos)) |