aboutsummaryrefslogtreecommitdiff
path: root/compiler/src/dotty/tools/dotc/ast/Desugar.scala
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/src/dotty/tools/dotc/ast/Desugar.scala')
-rw-r--r--compiler/src/dotty/tools/dotc/ast/Desugar.scala90
1 files changed, 55 insertions, 35 deletions
diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala
index 87994a87b..863c10cb0 100644
--- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala
+++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala
@@ -14,6 +14,7 @@ import reporting.diagnostic.messages._
object desugar {
import untpd._
+ import DesugarEnums._
/** Tags a .withFilter call generated by desugaring a for expression.
* Such calls can alternatively be rewritten to use filter.
@@ -263,7 +264,9 @@ object desugar {
val className = checkNotReservedName(cdef).asTypeName
val impl @ Template(constr0, parents, self, _) = cdef.rhs
val mods = cdef.mods
- val companionMods = mods.withFlags((mods.flags & AccessFlags).toCommonFlags)
+ val companionMods = mods
+ .withFlags((mods.flags & AccessFlags).toCommonFlags)
+ .withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase]))
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match {
case meth: DefDef => (meth, Nil)
@@ -288,17 +291,22 @@ object desugar {
}
val isCaseClass = mods.is(Case) && !mods.is(Module)
+ val isEnum = mods.hasMod[Mod.Enum]
+ val isEnumCase = isLegalEnumCase(cdef)
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
- val constrTparams = constr1.tparams map toDefParam
+ val originalTparams =
+ if (isEnumCase && parents.isEmpty) reconstitutedEnumTypeParams(cdef.pos.startPos)
+ else constr1.tparams
+ val originalVparamss = constr1.vparamss
+ val constrTparams = originalTparams.map(toDefParam)
val constrVparamss =
- if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty
- if (isCaseClass)
- ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
+ if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
+ if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
ListOfNil
}
- else constr1.vparamss.nestedMap(toDefParam)
+ else originalVparamss.nestedMap(toDefParam)
val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)
// Add constructor type parameters and evidence implicit parameters
@@ -312,21 +320,22 @@ object desugar {
stat
}
- val derivedTparams = constrTparams map derivedTypeParam
+ val derivedTparams =
+ if (isEnumCase) constrTparams else constrTparams map derivedTypeParam
val derivedVparamss = constrVparamss nestedMap derivedTermParam
val arity = constrVparamss.head.length
- var classTycon: Tree = EmptyTree
+ val classTycon: Tree = new TypeRefTree // watching is set at end of method
- // a reference to the class type, with all parameters given.
- val classTypeRef/*: Tree*/ = {
- // -language:keepUnions difference: classTypeRef needs type annotation, otherwise
- // infers Ident | AppliedTypeTree, which
- // renders the :\ in companions below untypable.
- classTycon = (new TypeRefTree) withPos cdef.pos.startPos // watching is set at end of method
- val tparams = impl.constr.tparams
- if (tparams.isEmpty) classTycon else AppliedTypeTree(classTycon, tparams map refOfDef)
- }
+ def appliedRef(tycon: Tree) =
+ (if (constrTparams.isEmpty) tycon
+ else AppliedTypeTree(tycon, constrTparams map refOfDef))
+ .withPos(cdef.pos.startPos)
+
+ // a reference to the class type bound by `cdef`, with type parameters coming from the constructor
+ val classTypeRef = appliedRef(classTycon)
+ // a refereence to `enumClass`, with type parameters coming from the constructor
+ lazy val enumClassTypeRef = appliedRef(enumClassRef)
// new C[Ts](paramss)
lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef)
@@ -374,7 +383,9 @@ object desugar {
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
.withMods(synthetic) :: Nil
}
- copyMeths ::: productElemMeths.toList
+
+ val enumTagMeths = if (isEnumCase) enumTagMeth :: Nil else Nil
+ copyMeths ::: enumTagMeths ::: productElemMeths.toList
}
else Nil
@@ -387,8 +398,12 @@ object desugar {
// Case classes and case objects get a ProductN parent
var parents1 = parents
+ if (isEnumCase && parents.isEmpty)
+ parents1 = enumClassTypeRef :: Nil
if (mods.is(Case) && arity <= Definitions.MaxTupleArity)
- parents1 = parents1 :+ productConstr(arity)
+ parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that?
+ if (isEnum)
+ parents1 = parents1 :+ ref(defn.EnumType)
// The thicket which is the desugared version of the companion object
// synthetic object C extends parentTpt { defs }
@@ -419,9 +434,11 @@ object desugar {
else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function(vparams map (_.tpt), restpe))
val applyMeths =
if (mods is Abstract) Nil
- else
- DefDef(nme.apply, derivedTparams, derivedVparamss, TypeTree(), creatorExpr)
+ else {
+ val restpe = if (isEnumCase) enumClassTypeRef else TypeTree()
+ DefDef(nme.apply, derivedTparams, derivedVparamss, restpe, creatorExpr)
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil
+ }
val unapplyMeth = {
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
@@ -464,15 +481,15 @@ object desugar {
else cpy.ValDef(self)(tpt = selfType).withMods(self.mods | SelfName)
}
- val cdef1 = {
- val originalTparams = constr1.tparams.toIterator
- val originalVparams = constr1.vparamss.toIterator.flatten
- val tparamAccessors = derivedTparams.map(_.withMods(originalTparams.next.mods))
+ val cdef1 = addEnumFlags {
+ val originalTparamsIt = originalTparams.toIterator
+ val originalVparamsIt = originalVparamss.toIterator.flatten
+ val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next.mods))
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
val vparamAccessors = derivedVparamss match {
case first :: rest =>
- first.map(_.withMods(originalVparams.next.mods | caseAccessor)) ++
- rest.flatten.map(_.withMods(originalVparams.next.mods))
+ first.map(_.withMods(originalVparamsIt.next.mods | caseAccessor)) ++
+ rest.flatten.map(_.withMods(originalVparamsIt.next.mods))
case _ =>
Nil
}
@@ -503,23 +520,26 @@ object desugar {
*/
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
val moduleName = checkNotReservedName(mdef).asTermName
- val tmpl = mdef.impl
+ val impl = mdef.impl
val mods = mdef.mods
+ lazy val isEnumCase = isLegalEnumCase(mdef)
if (mods is Package)
- PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, tmpl).withMods(mods &~ Package) :: Nil)
+ PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil)
+ else if (isEnumCase)
+ expandEnumModule(moduleName, impl, mods, mdef.pos)
else {
val clsName = moduleName.moduleClassName
val clsRef = Ident(clsName)
val modul = ValDef(moduleName, clsRef, New(clsRef, Nil))
.withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags)
.withPos(mdef.pos)
- val ValDef(selfName, selfTpt, _) = tmpl.self
- val selfMods = tmpl.self.mods
- if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), tmpl.self.pos)
- val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), tmpl.self.rhs)
+ val ValDef(selfName, selfTpt, _) = impl.self
+ val selfMods = impl.self.mods
+ if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), impl.self.pos)
+ val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), impl.self.rhs)
.withMods(selfMods)
- .withPos(tmpl.self.pos orElse tmpl.pos.startPos)
- val clsTmpl = cpy.Template(tmpl)(self = clsSelf, body = tmpl.body)
+ .withPos(impl.self.pos orElse impl.pos.startPos)
+ val clsTmpl = cpy.Template(impl)(self = clsSelf, body = impl.body)
val cls = TypeDef(clsName, clsTmpl)
.withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags)
Thicket(modul, classDef(cls).withPos(mdef.pos))