diff options
author | Martin Odersky <odersky@gmail.com> | 2017-02-08 21:23:15 +1100 |
---|---|---|
committer | Martin Odersky <odersky@gmail.com> | 2017-04-04 13:28:44 +0200 |
commit | cea243a4fc38dcc8831000d1066e10362df37576 (patch) | |
tree | 38bf1795d9c0d4ce973e405b842d80cec40ae215 | |
parent | 41d83d42650d0c0b54c47c1a9043d0b92315aa4e (diff) | |
download | dotty-cea243a4fc38dcc8831000d1066e10362df37576.tar.gz dotty-cea243a4fc38dcc8831000d1066e10362df37576.tar.bz2 dotty-cea243a4fc38dcc8831000d1066e10362df37576.zip |
Implement enum desugaring
-rw-r--r-- | compiler/src/dotty/tools/dotc/ast/Desugar.scala | 90 | ||||
-rw-r--r-- | compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala | 124 | ||||
-rw-r--r-- | compiler/src/dotty/tools/dotc/core/Definitions.scala | 4 | ||||
-rw-r--r-- | compiler/src/dotty/tools/dotc/core/StdNames.scala | 7 | ||||
-rw-r--r-- | library/src/scala/Enum.scala | 8 | ||||
-rw-r--r-- | library/src/scala/runtime/EnumValues.scala | 18 |
6 files changed, 215 insertions, 36 deletions
diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 87994a87b..863c10cb0 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -14,6 +14,7 @@ import reporting.diagnostic.messages._ object desugar { import untpd._ + import DesugarEnums._ /** Tags a .withFilter call generated by desugaring a for expression. * Such calls can alternatively be rewritten to use filter. @@ -263,7 +264,9 @@ object desugar { val className = checkNotReservedName(cdef).asTypeName val impl @ Template(constr0, parents, self, _) = cdef.rhs val mods = cdef.mods - val companionMods = mods.withFlags((mods.flags & AccessFlags).toCommonFlags) + val companionMods = mods + .withFlags((mods.flags & AccessFlags).toCommonFlags) + .withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase])) val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match { case meth: DefDef => (meth, Nil) @@ -288,17 +291,22 @@ object desugar { } val isCaseClass = mods.is(Case) && !mods.is(Module) + val isEnum = mods.hasMod[Mod.Enum] + val isEnumCase = isLegalEnumCase(cdef) val isValueClass = parents.nonEmpty && isAnyVal(parents.head) // This is not watertight, but `extends AnyVal` will be replaced by `inline` later. - val constrTparams = constr1.tparams map toDefParam + val originalTparams = + if (isEnumCase && parents.isEmpty) reconstitutedEnumTypeParams(cdef.pos.startPos) + else constr1.tparams + val originalVparamss = constr1.vparamss + val constrTparams = originalTparams.map(toDefParam) val constrVparamss = - if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty - if (isCaseClass) - ctx.error(CaseClassMissingParamList(cdef), cdef.namePos) + if (originalVparamss.isEmpty) { // ensure parameter list is non-empty + if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos) ListOfNil } - else constr1.vparamss.nestedMap(toDefParam) + else originalVparamss.nestedMap(toDefParam) val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss) // Add constructor type parameters and evidence implicit parameters @@ -312,21 +320,22 @@ object desugar { stat } - val derivedTparams = constrTparams map derivedTypeParam + val derivedTparams = + if (isEnumCase) constrTparams else constrTparams map derivedTypeParam val derivedVparamss = constrVparamss nestedMap derivedTermParam val arity = constrVparamss.head.length - var classTycon: Tree = EmptyTree + val classTycon: Tree = new TypeRefTree // watching is set at end of method - // a reference to the class type, with all parameters given. - val classTypeRef/*: Tree*/ = { - // -language:keepUnions difference: classTypeRef needs type annotation, otherwise - // infers Ident | AppliedTypeTree, which - // renders the :\ in companions below untypable. - classTycon = (new TypeRefTree) withPos cdef.pos.startPos // watching is set at end of method - val tparams = impl.constr.tparams - if (tparams.isEmpty) classTycon else AppliedTypeTree(classTycon, tparams map refOfDef) - } + def appliedRef(tycon: Tree) = + (if (constrTparams.isEmpty) tycon + else AppliedTypeTree(tycon, constrTparams map refOfDef)) + .withPos(cdef.pos.startPos) + + // a reference to the class type bound by `cdef`, with type parameters coming from the constructor + val classTypeRef = appliedRef(classTycon) + // a refereence to `enumClass`, with type parameters coming from the constructor + lazy val enumClassTypeRef = appliedRef(enumClassRef) // new C[Ts](paramss) lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef) @@ -374,7 +383,9 @@ object desugar { DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr) .withMods(synthetic) :: Nil } - copyMeths ::: productElemMeths.toList + + val enumTagMeths = if (isEnumCase) enumTagMeth :: Nil else Nil + copyMeths ::: enumTagMeths ::: productElemMeths.toList } else Nil @@ -387,8 +398,12 @@ object desugar { // Case classes and case objects get a ProductN parent var parents1 = parents + if (isEnumCase && parents.isEmpty) + parents1 = enumClassTypeRef :: Nil if (mods.is(Case) && arity <= Definitions.MaxTupleArity) - parents1 = parents1 :+ productConstr(arity) + parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that? + if (isEnum) + parents1 = parents1 :+ ref(defn.EnumType) // The thicket which is the desugared version of the companion object // synthetic object C extends parentTpt { defs } @@ -419,9 +434,11 @@ object desugar { else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function(vparams map (_.tpt), restpe)) val applyMeths = if (mods is Abstract) Nil - else - DefDef(nme.apply, derivedTparams, derivedVparamss, TypeTree(), creatorExpr) + else { + val restpe = if (isEnumCase) enumClassTypeRef else TypeTree() + DefDef(nme.apply, derivedTparams, derivedVparamss, restpe, creatorExpr) .withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil + } val unapplyMeth = { val unapplyParam = makeSyntheticParameter(tpt = classTypeRef) val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name) @@ -464,15 +481,15 @@ object desugar { else cpy.ValDef(self)(tpt = selfType).withMods(self.mods | SelfName) } - val cdef1 = { - val originalTparams = constr1.tparams.toIterator - val originalVparams = constr1.vparamss.toIterator.flatten - val tparamAccessors = derivedTparams.map(_.withMods(originalTparams.next.mods)) + val cdef1 = addEnumFlags { + val originalTparamsIt = originalTparams.toIterator + val originalVparamsIt = originalVparamss.toIterator.flatten + val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next.mods)) val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags val vparamAccessors = derivedVparamss match { case first :: rest => - first.map(_.withMods(originalVparams.next.mods | caseAccessor)) ++ - rest.flatten.map(_.withMods(originalVparams.next.mods)) + first.map(_.withMods(originalVparamsIt.next.mods | caseAccessor)) ++ + rest.flatten.map(_.withMods(originalVparamsIt.next.mods)) case _ => Nil } @@ -503,23 +520,26 @@ object desugar { */ def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = { val moduleName = checkNotReservedName(mdef).asTermName - val tmpl = mdef.impl + val impl = mdef.impl val mods = mdef.mods + lazy val isEnumCase = isLegalEnumCase(mdef) if (mods is Package) - PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, tmpl).withMods(mods &~ Package) :: Nil) + PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil) + else if (isEnumCase) + expandEnumModule(moduleName, impl, mods, mdef.pos) else { val clsName = moduleName.moduleClassName val clsRef = Ident(clsName) val modul = ValDef(moduleName, clsRef, New(clsRef, Nil)) .withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags) .withPos(mdef.pos) - val ValDef(selfName, selfTpt, _) = tmpl.self - val selfMods = tmpl.self.mods - if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), tmpl.self.pos) - val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), tmpl.self.rhs) + val ValDef(selfName, selfTpt, _) = impl.self + val selfMods = impl.self.mods + if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), impl.self.pos) + val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), impl.self.rhs) .withMods(selfMods) - .withPos(tmpl.self.pos orElse tmpl.pos.startPos) - val clsTmpl = cpy.Template(tmpl)(self = clsSelf, body = tmpl.body) + .withPos(impl.self.pos orElse impl.pos.startPos) + val clsTmpl = cpy.Template(impl)(self = clsSelf, body = impl.body) val cls = TypeDef(clsName, clsTmpl) .withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags) Thicket(modul, classDef(cls).withPos(mdef.pos)) diff --git a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala new file mode 100644 index 000000000..4317c8183 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala @@ -0,0 +1,124 @@ +package dotty.tools +package dotc +package ast + +import core._ +import util.Positions._, Types._, Contexts._, Constants._, Names._, NameOps._, Flags._ +import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._ +import Decorators._ +import collection.mutable.ListBuffer +import util.Property +import reporting.diagnostic.messages._ + +object DesugarEnums { + import untpd._ + import desugar.DerivedFromParamTree + + val EnumCaseCount = new Property.Key[Int] + + def enumClass(implicit ctx: Context) = ctx.owner.linkedClass + + def nextEnumTag(implicit ctx: Context): Int = { + val result = ctx.tree.removeAttachment(EnumCaseCount).getOrElse(0) + ctx.tree.pushAttachment(EnumCaseCount, result + 1) + result + } + + def isLegalEnumCase(tree: MemberDef)(implicit ctx: Context): Boolean = { + tree.mods.hasMod[Mod.EnumCase] && + ( ctx.owner.is(ModuleClass) && enumClass.derivesFrom(defn.EnumClass) + || { ctx.error(em"case not allowed here, since owner ${ctx.owner} is not an `enum' object", tree.pos) + false + } + ) + } + + /** Type parameters reconstituted from the constructor + * of the `enum' class corresponding to an enum case + */ + def reconstitutedEnumTypeParams(pos: Position)(implicit ctx: Context) = { + val tparams = enumClass.primaryConstructor.info match { + case info: PolyType => + ctx.newTypeParams(ctx.newLocalDummy(enumClass), info.paramNames, EmptyFlags, info.instantiateBounds) + case _ => + Nil + } + for (tparam <- tparams) yield { + val tbounds = new DerivedFromParamTree + tbounds.pushAttachment(OriginalSymbol, tparam) + TypeDef(tparam.name, tbounds) + .withFlags(Param | PrivateLocal).withPos(pos) + } + } + + def enumTagMeth(implicit ctx: Context) = + DefDef(nme.enumTag, Nil, Nil, TypeTree(), Literal(Constant(nextEnumTag))) + + def enumClassRef(implicit ctx: Context) = TypeTree(enumClass.typeRef) + + def addEnumFlags(cdef: TypeDef)(implicit ctx: Context) = + if (cdef.mods.hasMod[Mod.Enum]) cdef.withFlags(cdef.mods.flags | Abstract | Sealed) + else if (isLegalEnumCase(cdef)) cdef.withFlags(cdef.mods.flags | Final) + else cdef + + /** The following lists of definitions for an enum type E: + * + * private val $values = new EnumValues[E] + * def valueOf: Int => E = $values + * def values = $values.values + * + * private def $new(tag: Int, name: String) = new E { + * def enumTag = tag + * override def toString = name + * $values.register(this) + * } + */ + private def enumScaffolding(implicit ctx: Context): List[Tree] = { + val valsRef = Ident(nme.DOLLAR_VALUES) + def param(name: TermName, typ: Type) = + ValDef(name, TypeTree(typ), EmptyTree).withFlags(Param) + val privateValuesDef = + ValDef(nme.DOLLAR_VALUES, TypeTree(), + New(TypeTree(defn.EnumValuesType.appliedTo(enumClass.typeRef :: Nil)), ListOfNil)) + .withFlags(Private) + val valueOfDef = + DefDef(nme.valueOf, Nil, Nil, + TypeTree(defn.FunctionOf(defn.IntType :: Nil, enumClass.typeRef)), valsRef) + val valuesDef = + DefDef(nme.values, Nil, Nil, TypeTree(), Select(valsRef, nme.values)) + val enumTagDef = + DefDef(nme.enumTag, Nil, Nil, TypeTree(), Ident(nme.tag)) + val toStringDef = + DefDef(nme.toString_, Nil, Nil, TypeTree(), Ident(nme.name)) + .withFlags(Override) + val registerStat = + Apply(Select(valsRef, nme.register), This(EmptyTypeIdent) :: Nil) + def creator = New(Template(emptyConstructor, enumClassRef :: Nil, EmptyValDef, + List(enumTagDef, toStringDef, registerStat))) + val newDef = + DefDef(nme.DOLLAR_NEW, Nil, + List(List(param(nme.tag, defn.IntType), param(nme.name, defn.StringType))), + TypeTree(), creator) + List(privateValuesDef, valueOfDef, valuesDef, newDef) + } + + def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree = { + def nameLit = Literal(Constant(name.toString)) + if (impl.parents.isEmpty) { + if (reconstitutedEnumTypeParams(pos).nonEmpty) + ctx.error(i"illegal enum value of generic $enumClass: an explicit `extends' clause is needed", pos) + val tag = nextEnumTag + val prefix = if (tag == 0) enumScaffolding else Nil + val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), nameLit)) + val vdef = ValDef(name, enumClassRef, creator).withMods(mods | Final).withPos(pos) + flatTree(prefix ::: vdef :: Nil).withPos(pos.startPos) + } else { + def toStringMeth = + DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), nameLit) + .withFlags(Override) + val impl1 = cpy.Template(impl)(body = + impl.body ++ List(enumTagMeth, toStringMeth)) + ValDef(name, TypeTree(), New(impl1)).withMods(mods | Final).withPos(pos) + } + } +} diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 4d4350f98..39b46cbfe 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -512,6 +512,10 @@ class Definitions { def DynamicClass(implicit ctx: Context) = DynamicType.symbol.asClass lazy val OptionType: TypeRef = ctx.requiredClassRef("scala.Option") def OptionClass(implicit ctx: Context) = OptionType.symbol.asClass + lazy val EnumType: TypeRef = ctx.requiredClassRef("scala.Enum") + def EnumClass(implicit ctx: Context) = EnumType.symbol.asClass + lazy val EnumValuesType: TypeRef = ctx.requiredClassRef("scala.runtime.EnumValues") + def EnumValuesClass(implicit ctx: Context) = EnumValuesType.symbol.asClass lazy val ProductType: TypeRef = ctx.requiredClassRef("scala.Product") def ProductClass(implicit ctx: Context) = ProductType.symbol.asClass lazy val Product_canEqualR = ProductClass.requiredMethodRef(nme.canEqual_) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 5b7dc3d1d..ff3ddbad7 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -132,6 +132,8 @@ object StdNames { val TRAIT_SETTER_SEPARATOR: N = "$_setter_$" val DIRECT_SUFFIX: N = "$direct" val LAZY_IMPLICIT_PREFIX: N = "$lazy_implicit$" + val DOLLAR_VALUES: N = "$values" + val DOLLAR_NEW: N = "$new" // value types (and AnyRef) are all used as terms as well // as (at least) arguments to the @specialize annotation. @@ -395,6 +397,7 @@ object StdNames { val elem: N = "elem" val emptyValDef: N = "emptyValDef" val ensureAccessible : N = "ensureAccessible" + val enumTag: N = "enumTag" val eq: N = "eq" val equalsNumChar : N = "equalsNumChar" val equalsNumNum : N = "equalsNumNum" @@ -474,6 +477,7 @@ object StdNames { val productPrefix: N = "productPrefix" val readResolve: N = "readResolve" val reflect : N = "reflect" + val register: N = "register" val reify : N = "reify" val rootMirror : N = "rootMirror" val runOrElse: N = "runOrElse" @@ -499,6 +503,7 @@ object StdNames { val staticModule : N = "staticModule" val staticPackage : N = "staticPackage" val synchronized_ : N = "synchronized" + val tag: N = "tag" val tail: N = "tail" val `then` : N = "then" val this_ : N = "this" @@ -523,7 +528,7 @@ object StdNames { val updateDynamic: N = "updateDynamic" val value: N = "value" val valueOf : N = "valueOf" - val values : N = "values" + val values: N = "values" val view_ : N = "view" val wait_ : N = "wait" val withFilter: N = "withFilter" diff --git a/library/src/scala/Enum.scala b/library/src/scala/Enum.scala new file mode 100644 index 000000000..7d2eefb3d --- /dev/null +++ b/library/src/scala/Enum.scala @@ -0,0 +1,8 @@ +package scala + +/** A base trait of all enum classes */ +trait Enum { + + /** A number uniquely identifying a case of an enum */ + def enumTag: Int +} diff --git a/library/src/scala/runtime/EnumValues.scala b/library/src/scala/runtime/EnumValues.scala new file mode 100644 index 000000000..6d2e56cf3 --- /dev/null +++ b/library/src/scala/runtime/EnumValues.scala @@ -0,0 +1,18 @@ +package scala.runtime + +import scala.collection.immutable.Seq +import scala.collection.mutable.ResizableArray + +class EnumValues[E <: Enum] extends ResizableArray[E] { + private var valuesCache: List[E] = Nil + def register(v: E) = { + ensureSize(v.enumTag + 1) + size0 = size0 max (v.enumTag + 1) + array(v.enumTag) = v + valuesCache = null + } + def values: Seq[E] = { + if (valuesCache == null) valuesCache = array.filter(_ != null).toList.asInstanceOf[scala.List[E]] + valuesCache + } +} |