diff options
Diffstat (limited to 'examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer')
7 files changed, 6660 insertions, 0 deletions
diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/Analyzer.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/Analyzer.scala new file mode 100644 index 0000000..9cdd764 --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/Analyzer.scala @@ -0,0 +1,587 @@ +/* __ *\ +** ________ ___ / / ___ __ ____ Scala.js tools ** +** / __/ __// _ | / / / _ | __ / // __/ (c) 2013-2014, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ |/_// /_\ \ http://scala-js.org/ ** +** /____/\___/_/ |_/____/_/ | |__/ /____/ ** +** |/____/ ** +\* */ + + +package scala.scalajs.tools.optimizer + +import scala.annotation.tailrec + +import scala.collection.mutable + +import scala.scalajs.ir +import ir.{ClassKind, Definitions, Infos} + +import scala.scalajs.tools.sem._ +import scala.scalajs.tools.javascript.LongImpl +import scala.scalajs.tools.logging._ + +import ScalaJSOptimizer._ + +class Analyzer(logger0: Logger, semantics: Semantics, + allData: Seq[Infos.ClassInfo], globalWarnEnabled: Boolean, + isBeforeOptimizer: Boolean) { + /* Set this to true to debug the DCE analyzer. + * We don't rely on config to disable 'debug' messages because we want + * to use 'debug' for displaying more stack trace info that the user can + * see with the 'last' command. + */ + val DebugAnalyzer = false + + object logger extends Logger { + var indentation: String = "" + + def indent(): Unit = indentation += " " + def undent(): Unit = indentation = indentation.substring(2) + + def log(level: Level, message: => String) = + logger0.log(level, indentation+message) + def success(message: => String) = + logger0.success(indentation+message) + def trace(t: => Throwable) = + logger0.trace(t) + + def indented[A](body: => A): A = { + indent() + try body + finally undent() + } + + def debugIndent[A](message: => String)(body: => A): A = { + if (DebugAnalyzer) { + debug(message) + indented(body) + } else { + body + } + } + + def temporarilyNotIndented[A](body: => A): A = { + val savedIndent = indentation + indentation = "" + try body + finally indentation = savedIndent + } + } + + sealed trait From + case class FromMethod(methodInfo: MethodInfo) extends From + case object FromCore extends From + case object FromExports extends From + case object FromManual extends From + + var allAvailable: Boolean = true + + val classInfos: mutable.Map[String, ClassInfo] = { + val cs = for (classData <- allData) + yield (classData.encodedName, new ClassInfo(classData)) + mutable.Map.empty[String, ClassInfo] ++ cs + } + + def lookupClass(encodedName: String): ClassInfo = { + classInfos.get(encodedName) match { + case Some(info) => info + case None => + val c = new ClassInfo(createMissingClassInfo(encodedName)) + classInfos += encodedName -> c + c.nonExistent = true + c.linkClasses() + c + } + } + + def lookupModule(encodedName: String): ClassInfo = { + lookupClass(encodedName+"$") + } + + linkClasses() + + def linkClasses(): Unit = { + if (!classInfos.contains(ir.Definitions.ObjectClass)) + sys.error("Fatal error: could not find java.lang.Object on the classpath") + for (classInfo <- classInfos.values.toList) + classInfo.linkClasses() + } + + def computeReachability(manuallyReachable: Seq[ManualReachability], + noWarnMissing: Seq[NoWarnMissing]): Unit = { + // Stuff reachable from core symbols always should warn + reachCoreSymbols() + + // Disable warnings as requested + noWarnMissing.foreach(disableWarning _) + + // Reach all user stuff + manuallyReachable.foreach(reachManually _) + for (classInfo <- classInfos.values) + classInfo.reachExports() + } + + /** Reach symbols used directly by scalajsenv.js. */ + def reachCoreSymbols(): Unit = { + import semantics._ + import CheckedBehavior._ + + implicit val from = FromCore + + def instantiateClassWith(className: String, constructor: String): ClassInfo = { + val info = lookupClass(className) + info.instantiated() + info.callMethod(constructor) + info + } + + val ObjectClass = instantiateClassWith("O", "init___") + ObjectClass.callMethod("toString__T") + ObjectClass.callMethod("equals__O__Z") + + instantiateClassWith("jl_NullPointerException", "init___") + + if (asInstanceOfs != Unchecked) + instantiateClassWith("jl_ClassCastException", "init___T") + + if (asInstanceOfs == Fatal) + instantiateClassWith("sjsr_UndefinedBehaviorError", "init___jl_Throwable") + + instantiateClassWith("jl_Class", "init___jl_ScalaJSClassData") + + val RTStringModuleClass = lookupClass("sjsr_RuntimeString$") + RTStringModuleClass.accessModule() + RTStringModuleClass.callMethod("hashCode__T__I") + + val RTLongClass = lookupClass(LongImpl.RuntimeLongClass) + RTLongClass.instantiated() + for (method <- LongImpl.AllConstructors ++ LongImpl.AllMethods) + RTLongClass.callMethod(method) + + if (isBeforeOptimizer) { + for (method <- LongImpl.AllIntrinsicMethods) + RTLongClass.callMethod(method) + } + + val RTLongModuleClass = lookupClass(LongImpl.RuntimeLongModuleClass) + RTLongModuleClass.accessModule() + for (method <- LongImpl.AllModuleMethods) + RTLongModuleClass.callMethod(method) + + if (isBeforeOptimizer) { + for (hijacked <- Definitions.HijackedClasses) + lookupClass(hijacked).instantiated() + } else { + for (hijacked <- Definitions.HijackedClasses) + lookupClass(hijacked).accessData() + } + + if (semantics.strictFloats) { + val RuntimePackage = lookupClass("sjsr_package$") + RuntimePackage.accessModule() + RuntimePackage.callMethod("froundPolyfill__D__D") + } + + val BitsModuleClass = lookupClass("sjsr_Bits$") + BitsModuleClass.accessModule() + BitsModuleClass.callMethod("numberHashCode__D__I") + } + + def reachManually(info: ManualReachability) = { + implicit val from = FromManual + + // Don't lookupClass here, since we don't want to create any + // symbols. If a symbol doesn't exist, we fail. + info match { + case ReachObject(name) => classInfos(name + "$").accessModule() + case Instantiate(name) => classInfos(name).instantiated() + case ReachMethod(className, methodName, static) => + classInfos(className).callMethod(methodName, static) + } + } + + def disableWarning(noWarn: NoWarnMissing) = noWarn match { + case NoWarnClass(className) => + lookupClass(className).warnEnabled = false + case NoWarnMethod(className, methodName) => + lookupClass(className).lookupMethod(methodName).warnEnabled = false + } + + class ClassInfo(data: Infos.ClassInfo) { + val encodedName = data.encodedName + val ancestorCount = data.ancestorCount + val isStaticModule = data.kind == ClassKind.ModuleClass + val isInterface = data.kind == ClassKind.Interface + val isImplClass = data.kind == ClassKind.TraitImpl + val isRawJSType = data.kind == ClassKind.RawJSType + val isHijackedClass = data.kind == ClassKind.HijackedClass + val isClass = !isInterface && !isImplClass && !isRawJSType + val isExported = data.isExported + + val hasData = !isImplClass + val hasMoreThanData = isClass && !isHijackedClass + + var superClass: ClassInfo = _ + val ancestors = mutable.ListBuffer.empty[ClassInfo] + val descendants = mutable.ListBuffer.empty[ClassInfo] + + var nonExistent: Boolean = false + var warnEnabled: Boolean = true + + def linkClasses(): Unit = { + if (data.superClass != "") + superClass = lookupClass(data.superClass) + ancestors ++= data.ancestors.map(lookupClass) + for (ancestor <- ancestors) + ancestor.descendants += this + } + + lazy val descendentClasses = descendants.filter(_.isClass) + + def optimizerHints: Infos.OptimizerHints = data.optimizerHints + + var isInstantiated: Boolean = false + var isAnySubclassInstantiated: Boolean = false + var isModuleAccessed: Boolean = false + var isDataAccessed: Boolean = false + + var instantiatedFrom: Option[From] = None + + val delayedCalls = mutable.Map.empty[String, From] + + def isNeededAtAll = + isDataAccessed || + isAnySubclassInstantiated || + (isImplClass && methodInfos.values.exists(_.isReachable)) + + lazy val methodInfos: mutable.Map[String, MethodInfo] = { + val ms = for (methodData <- data.methods) + yield (methodData.encodedName, new MethodInfo(this, methodData)) + mutable.Map.empty[String, MethodInfo] ++ ms + } + + def lookupMethod(methodName: String): MethodInfo = { + tryLookupMethod(methodName).getOrElse { + val syntheticData = createMissingMethodInfo(methodName) + val m = new MethodInfo(this, syntheticData) + m.nonExistent = true + methodInfos += methodName -> m + m + } + } + + def tryLookupMethod(methodName: String): Option[MethodInfo] = { + assert(isClass || isImplClass, + s"Cannot call lookupMethod($methodName) on non-class $this") + @tailrec + def loop(ancestorInfo: ClassInfo): Option[MethodInfo] = { + if (ancestorInfo ne null) { + ancestorInfo.methodInfos.get(methodName) match { + case Some(m) if !m.isAbstract => Some(m) + case _ => loop(ancestorInfo.superClass) + } + } else { + None + } + } + loop(this) + } + + override def toString(): String = encodedName + + /** Start reachability algorithm with the exports for that class. */ + def reachExports(): Unit = { + implicit val from = FromExports + + // Myself + if (isExported) { + assert(!isImplClass, "An implementation class must not be exported") + if (isStaticModule) accessModule() + else instantiated() + } + + // My methods + for (methodInfo <- methodInfos.values) { + if (methodInfo.isExported) + callMethod(methodInfo.encodedName) + } + } + + def accessModule()(implicit from: From): Unit = { + assert(isStaticModule, s"Cannot call accessModule() on non-module $this") + if (!isModuleAccessed) { + logger.debugIndent(s"$this.isModuleAccessed = true") { + isModuleAccessed = true + instantiated() + callMethod("init___") + } + } + } + + def instantiated()(implicit from: From): Unit = { + if (!isInstantiated && isClass) { + logger.debugIndent(s"$this.isInstantiated = true") { + isInstantiated = true + instantiatedFrom = Some(from) + ancestors.foreach(_.subclassInstantiated()) + } + + for ((methodName, from) <- delayedCalls) + delayedCallMethod(methodName)(from) + } + } + + private def subclassInstantiated()(implicit from: From): Unit = { + if (!isAnySubclassInstantiated && isClass) { + logger.debugIndent(s"$this.isAnySubclassInstantiated = true") { + isAnySubclassInstantiated = true + if (instantiatedFrom.isEmpty) + instantiatedFrom = Some(from) + accessData() + methodInfos.get("__init__").foreach(_.reachStatic()) + } + } + } + + def accessData()(implicit from: From): Unit = { + if (!isDataAccessed && hasData) { + checkExistent() + if (DebugAnalyzer) + logger.debug(s"$this.isDataAccessed = true") + isDataAccessed = true + } + } + + def checkExistent()(implicit from: From): Unit = { + if (nonExistent) { + if (warnEnabled && globalWarnEnabled) { + logger.warn(s"Referring to non-existent class $encodedName") + warnCallStack() + } + nonExistent = false + allAvailable = false + } + } + + def callMethod(methodName: String, static: Boolean = false)( + implicit from: From): Unit = { + logger.debugIndent(s"calling${if (static) " static" else ""} $this.$methodName") { + if (isImplClass) { + // methods in impl classes are always implicitly called statically + lookupMethod(methodName).reachStatic() + } else if (isConstructorName(methodName)) { + // constructors are always implicitly called statically + lookupMethod(methodName).reachStatic() + } else if (static) { + assert(!isReflProxyName(methodName), + s"Trying to call statically refl proxy $this.$methodName") + lookupMethod(methodName).reachStatic() + } else { + for (descendentClass <- descendentClasses) { + if (descendentClass.isInstantiated) + descendentClass.delayedCallMethod(methodName) + else + descendentClass.delayedCalls += ((methodName, from)) + } + } + } + } + + private def delayedCallMethod(methodName: String)(implicit from: From): Unit = { + if (isReflProxyName(methodName)) { + tryLookupMethod(methodName).foreach(_.reach(this)) + } else { + lookupMethod(methodName).reach(this) + } + } + } + + class MethodInfo(val owner: ClassInfo, data: Infos.MethodInfo) { + + val encodedName = data.encodedName + val isAbstract = data.isAbstract + val isExported = data.isExported + val isReflProxy = isReflProxyName(encodedName) + + def optimizerHints: Infos.OptimizerHints = data.optimizerHints + + var isReachable: Boolean = false + + var calledFrom: Option[From] = None + var instantiatedSubclass: Option[ClassInfo] = None + + var nonExistent: Boolean = false + var warnEnabled: Boolean = true + + override def toString(): String = s"$owner.$encodedName" + + def reachStatic()(implicit from: From): Unit = { + assert(!isAbstract, + s"Trying to reach statically the abstract method $this") + + checkExistent() + + if (!isReachable) { + logger.debugIndent(s"$this.isReachable = true") { + isReachable = true + calledFrom = Some(from) + doReach() + } + } + } + + def reach(inClass: ClassInfo)(implicit from: From): Unit = { + assert(owner.isClass, + s"Trying to reach dynamically the non-class method $this") + assert(!isConstructorName(encodedName), + s"Trying to reach dynamically the constructor $this") + + checkExistent() + + if (!isReachable) { + logger.debugIndent(s"$this.isReachable = true") { + isReachable = true + calledFrom = Some(from) + instantiatedSubclass = Some(inClass) + doReach() + } + } + } + + private def checkExistent()(implicit from: From) = { + if (nonExistent) { + if (warnEnabled && owner.warnEnabled && globalWarnEnabled) { + logger.temporarilyNotIndented { + logger.warn(s"Referring to non-existent method $this") + warnCallStack() + } + } + allAvailable = false + } + } + + private[this] def doReach(): Unit = { + logger.debugIndent(s"$this.doReach()") { + implicit val from = FromMethod(this) + + if (owner.isImplClass) + owner.checkExistent() + + for (moduleName <- data.accessedModules) { + lookupModule(moduleName).accessModule() + } + + for (className <- data.instantiatedClasses) { + lookupClass(className).instantiated() + } + + for (className <- data.accessedClassData) { + lookupClass(className).accessData() + } + + for ((className, methods) <- data.calledMethods) { + val classInfo = lookupClass(className) + for (methodName <- methods) + classInfo.callMethod(methodName) + } + + for ((className, methods) <- data.calledMethodsStatic) { + val classInfo = lookupClass(className) + for (methodName <- methods) + classInfo.callMethod(methodName, static = true) + } + } + } + } + + def isReflProxyName(encodedName: String): Boolean = { + encodedName.endsWith("__") && + (encodedName != "init___") && (encodedName != "__init__") + } + + def isConstructorName(encodedName: String): Boolean = + encodedName.startsWith("init___") || (encodedName == "__init__") + + private def createMissingClassInfo(encodedName: String): Infos.ClassInfo = { + val kind = + if (encodedName.endsWith("$")) ClassKind.ModuleClass + else if (encodedName.endsWith("$class")) ClassKind.TraitImpl + else ClassKind.Class + Infos.ClassInfo( + name = s"<$encodedName>", + encodedName = encodedName, + isExported = false, + ancestorCount = if (kind.isClass) 1 else 0, + kind = kind, + superClass = if (kind.isClass) "O" else "", + ancestors = List(encodedName, "O"), + methods = List( + createMissingMethodInfo("__init__"), + createMissingMethodInfo("init___")) + ) + } + + private def createMissingMethodInfo(encodedName: String, + isAbstract: Boolean = false): Infos.MethodInfo = { + Infos.MethodInfo(encodedName = encodedName, isAbstract = isAbstract) + } + + def warnCallStack()(implicit from: From): Unit = { + val seenInfos = mutable.Set.empty[AnyRef] + + def rec(level: Level, optFrom: Option[From], + verb: String = "called"): Unit = { + val involvedClasses = new mutable.ListBuffer[ClassInfo] + + def onlyOnce(info: AnyRef): Boolean = { + if (seenInfos.add(info)) { + true + } else { + logger.log(level, " (already seen, not repeating call stack)") + false + } + } + + @tailrec + def loopTrace(optFrom: Option[From], verb: String = "called"): Unit = { + optFrom match { + case None => + logger.log(level, s"$verb from ... er ... nowhere!? (this is a bug in dce)") + case Some(from) => + from match { + case FromMethod(methodInfo) => + logger.log(level, s"$verb from $methodInfo") + if (onlyOnce(methodInfo)) { + methodInfo.instantiatedSubclass.foreach(involvedClasses += _) + loopTrace(methodInfo.calledFrom) + } + case FromCore => + logger.log(level, s"$verb from scalajs-corejslib.js") + case FromExports => + logger.log(level, "exported to JavaScript with @JSExport") + case FromManual => + logger.log(level, "manually made reachable") + } + } + } + + logger.indented { + loopTrace(optFrom, verb = verb) + } + + if (involvedClasses.nonEmpty) { + logger.log(level, "involving instantiated classes:") + logger.indented { + for (classInfo <- involvedClasses.result().distinct) { + logger.log(level, s"$classInfo") + if (onlyOnce(classInfo)) + rec(Level.Debug, classInfo.instantiatedFrom, verb = "instantiated") + // recurse with Debug log level not to overwhelm the user + } + } + } + } + + rec(Level.Warn, Some(from)) + } +} diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/GenIncOptimizer.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/GenIncOptimizer.scala new file mode 100644 index 0000000..47e1f87 --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/GenIncOptimizer.scala @@ -0,0 +1,921 @@ +/* __ *\ +** ________ ___ / / ___ __ ____ Scala.js tools ** +** / __/ __// _ | / / / _ | __ / // __/ (c) 2013-2014, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ |/_// /_\ \ http://scala-js.org/ ** +** /____/\___/_/ |_/____/_/ | |__/ /____/ ** +** |/____/ ** +\* */ + + +package scala.scalajs.tools.optimizer + +import language.higherKinds + +import scala.annotation.{switch, tailrec} + +import scala.collection.{GenMap, GenTraversableOnce, GenIterable, GenIterableLike} +import scala.collection.mutable + +import scala.scalajs.ir._ +import Definitions.isConstructorName +import Infos.OptimizerHints +import Trees._ +import Types._ + +import scala.scalajs.tools.sem._ + +import scala.scalajs.tools.javascript +import javascript.Trees.{Tree => JSTree} +import javascript.ScalaJSClassEmitter + +import scala.scalajs.tools.logging._ + +/** Incremental optimizer. + * An incremental optimizer consumes the reachability analysis produced by + * an [[Analyzer]], as well as trees for classes, trait impls, etc., and + * optimizes them in an incremental way. + * It maintains state between runs to do a minimal amount of work on every + * run, based on detecting what parts of the program must be re-optimized, + * and keeping optimized results from previous runs for the rest. + */ +abstract class GenIncOptimizer(semantics: Semantics) { + import GenIncOptimizer._ + + protected val CollOps: AbsCollOps + + private val classEmitter = new ScalaJSClassEmitter(semantics) + + private var logger: Logger = _ + + /** Are we in batch mode? I.e., are we running from scratch? + * Various parts of the algorithm can be skipped entirely when running in + * batch mode. + */ + private var batchMode: Boolean = false + + /** Should positions be considered when comparing tree hashes */ + private var considerPositions: Boolean = _ + + private var objectClass: Class = _ + private val classes = CollOps.emptyMap[String, Class] + private val traitImpls = CollOps.emptyParMap[String, TraitImpl] + + protected def getInterface(encodedName: String): InterfaceType + + /** Schedule a method for processing in the PROCESS PASS */ + protected def scheduleMethod(method: MethodImpl): Unit + + protected def newMethodImpl(owner: MethodContainer, + encodedName: String): MethodImpl + + def findTraitImpl(encodedName: String): TraitImpl = traitImpls(encodedName) + def findClass(encodedName: String): Class = classes(encodedName) + + def getTraitImpl(encodedName: String): Option[TraitImpl] = traitImpls.get(encodedName) + def getClass(encodedName: String): Option[Class] = classes.get(encodedName) + + type GetClassTreeIfChanged = + (String, Option[String]) => Option[(ClassDef, Option[String])] + + private def withLogger[A](logger: Logger)(body: => A): A = { + assert(this.logger == null) + this.logger = logger + try body + finally this.logger = null + } + + /** Update the incremental analyzer with a new run. */ + def update(analyzer: Analyzer, + getClassTreeIfChanged: GetClassTreeIfChanged, considerPositions: Boolean, + logger: Logger): Unit = withLogger(logger) { + + batchMode = objectClass == null + this.considerPositions = considerPositions + logger.debug(s"Optimizer batch mode: $batchMode") + + logTime(logger, "Incremental part of inc. optimizer") { + /* UPDATE PASS */ + updateAndTagEverything(analyzer, getClassTreeIfChanged) + } + + logTime(logger, "Optimizer part of inc. optimizer") { + /* PROCESS PASS */ + processAllTaggedMethods() + } + } + + /** Incremental part: update state and detect what needs to be re-optimized. + * UPDATE PASS ONLY. (This IS the update pass). + */ + private def updateAndTagEverything(analyzer: Analyzer, + getClassTreeIfChanged: GetClassTreeIfChanged): Unit = { + + val neededClasses = CollOps.emptyParMap[String, analyzer.ClassInfo] + val neededTraitImpls = CollOps.emptyParMap[String, analyzer.ClassInfo] + for { + classInfo <- analyzer.classInfos.values + if classInfo.isNeededAtAll + } { + if (classInfo.isClass && classInfo.isAnySubclassInstantiated) + CollOps.put(neededClasses, classInfo.encodedName, classInfo) + else if (classInfo.isImplClass) + CollOps.put(neededTraitImpls, classInfo.encodedName, classInfo) + } + + /* Remove deleted trait impls, and update existing trait impls. + * We don't even have to notify callers in case of additions or removals + * because callers have got to be invalidated by themselves. + * Only changed methods need to trigger notifications. + * + * Non-batch mode only. + */ + assert(!batchMode || traitImpls.isEmpty) + if (!batchMode) { + CollOps.retain(traitImpls) { (traitImplName, traitImpl) => + CollOps.remove(neededTraitImpls, traitImplName).fold { + /* Deleted trait impl. Mark all its methods as deleted, and remove it + * from known trait impls. + */ + traitImpl.methods.values.foreach(_.delete()) + + false + } { traitImplInfo => + /* Existing trait impl. Update it. */ + val (added, changed, removed) = + traitImpl.updateWith(traitImplInfo, getClassTreeIfChanged) + for (method <- changed) + traitImpl.myInterface.tagStaticCallersOf(method) + + true + } + } + } + + /* Add new trait impls. + * Easy, we don't have to notify anyone. + */ + for (traitImplInfo <- neededTraitImpls.values) { + val traitImpl = new TraitImpl(traitImplInfo.encodedName) + CollOps.put(traitImpls, traitImpl.encodedName, traitImpl) + traitImpl.updateWith(traitImplInfo, getClassTreeIfChanged) + } + + if (!batchMode) { + /* Class removals: + * * If a class is deleted or moved, delete its entire subtree (because + * all its descendants must also be deleted or moved). + * * If an existing class was instantiated but is no more, notify callers + * of its methods. + * + * Non-batch mode only. + */ + val objectClassStillExists = + objectClass.walkClassesForDeletions(neededClasses.get(_)) + assert(objectClassStillExists, "Uh oh, java.lang.Object was deleted!") + + /* Class changes: + * * Delete removed methods, update existing ones, add new ones + * * Update the list of ancestors + * * Class newly instantiated + * + * Non-batch mode only. + */ + objectClass.walkForChanges( + CollOps.remove(neededClasses, _).get, + getClassTreeIfChanged, + Set.empty) + } + + /* Class additions: + * * Add new classes (including those that have moved from elsewhere). + * In batch mode, we avoid doing notifications. + */ + + // Group children by (immediate) parent + val newChildrenByParent = CollOps.emptyAccMap[String, Analyzer#ClassInfo] + + for (classInfo <- neededClasses.values) { + val superInfo = classInfo.superClass + if (superInfo == null) { + assert(batchMode, "Trying to add java.lang.Object in incremental mode") + objectClass = new Class(None, classInfo.encodedName) + classes += classInfo.encodedName -> objectClass + objectClass.setupAfterCreation(classInfo, getClassTreeIfChanged) + } else { + CollOps.acc(newChildrenByParent, superInfo.encodedName, classInfo) + } + } + + val getNewChildren = + (name: String) => CollOps.getAcc(newChildrenByParent, name) + + // Walk the tree to add children + if (batchMode) { + objectClass.walkForAdditions(getNewChildren, getClassTreeIfChanged) + } else { + val existingParents = + CollOps.parFlatMapKeys(newChildrenByParent)(classes.get) + for (parent <- existingParents) + parent.walkForAdditions(getNewChildren, getClassTreeIfChanged) + } + + } + + /** Optimizer part: process all methods that need reoptimizing. + * PROCESS PASS ONLY. (This IS the process pass). + */ + protected def processAllTaggedMethods(): Unit + + protected def logProcessingMethods(count: Int): Unit = + logger.debug(s"Optimizing $count methods.") + + /** Base class for [[Class]] and [[TraitImpl]]. */ + abstract class MethodContainer(val encodedName: String) { + def thisType: Type + + val myInterface = getInterface(encodedName) + + val methods = mutable.Map.empty[String, MethodImpl] + + var lastVersion: Option[String] = None + + private def reachableMethodsOf(info: Analyzer#ClassInfo): Set[String] = { + (for { + methodInfo <- info.methodInfos.values + if methodInfo.isReachable && !methodInfo.isAbstract + } yield { + methodInfo.encodedName + }).toSet + } + + /** UPDATE PASS ONLY. Global concurrency safe but not on same instance */ + def updateWith(info: Analyzer#ClassInfo, + getClassTreeIfChanged: GetClassTreeIfChanged): (Set[String], Set[String], Set[String]) = { + myInterface.ancestors = info.ancestors.map(_.encodedName).toList + + val addedMethods = Set.newBuilder[String] + val changedMethods = Set.newBuilder[String] + val deletedMethods = Set.newBuilder[String] + + val reachableMethods = reachableMethodsOf(info) + val methodSetChanged = methods.keySet != reachableMethods + if (methodSetChanged) { + // Remove deleted methods + methods retain { (methodName, method) => + if (reachableMethods.contains(methodName)) { + true + } else { + deletedMethods += methodName + method.delete() + false + } + } + // Clear lastVersion if there are new methods + if (reachableMethods.exists(!methods.contains(_))) + lastVersion = None + } + for ((tree, version) <- getClassTreeIfChanged(encodedName, lastVersion)) { + lastVersion = version + this match { + case cls: Class => + cls.isModuleClass = tree.kind == ClassKind.ModuleClass + cls.fields = for (field @ VarDef(_, _, _, _) <- tree.defs) yield field + case _ => + } + tree.defs.foreach { + case methodDef: MethodDef if methodDef.name.isInstanceOf[Ident] && + reachableMethods.contains(methodDef.name.name) => + val methodName = methodDef.name.name + + val methodInfo = info.methodInfos(methodName) + methods.get(methodName).fold { + addedMethods += methodName + val method = newMethodImpl(this, methodName) + method.updateWith(methodInfo, methodDef) + methods(methodName) = method + method + } { method => + if (method.updateWith(methodInfo, methodDef)) + changedMethods += methodName + method + } + + case _ => // ignore + } + } + + (addedMethods.result(), changedMethods.result(), deletedMethods.result()) + } + } + + /** Class in the class hierarchy (not an interface). + * A class may be a module class. + * A class knows its superclass and the interfaces it implements. It also + * maintains a list of its direct subclasses, so that the instances of + * [[Class]] form a tree of the class hierarchy. + */ + class Class(val superClass: Option[Class], + _encodedName: String) extends MethodContainer(_encodedName) { + if (encodedName == Definitions.ObjectClass) { + assert(superClass.isEmpty) + assert(objectClass == null) + } else { + assert(superClass.isDefined) + } + + /** Parent chain from this to Object. */ + val parentChain: List[Class] = + this :: superClass.fold[List[Class]](Nil)(_.parentChain) + + /** Reverse parent chain from Object to this. */ + val reverseParentChain: List[Class] = + parentChain.reverse + + def thisType: Type = ClassType(encodedName) + + var interfaces: Set[InterfaceType] = Set.empty + var subclasses: CollOps.ParIterable[Class] = CollOps.emptyParIterable + var isInstantiated: Boolean = false + + var isModuleClass: Boolean = false + var hasElidableModuleAccessor: Boolean = false + + var fields: List[VarDef] = Nil + var isInlineable: Boolean = false + var tryNewInlineable: Option[RecordValue] = None + + override def toString(): String = + encodedName + + /** Walk the class hierarchy tree for deletions. + * This includes "deleting" classes that were previously instantiated but + * are no more. + * UPDATE PASS ONLY. Not concurrency safe on same instance. + */ + def walkClassesForDeletions( + getClassInfoIfNeeded: String => Option[Analyzer#ClassInfo]): Boolean = { + def sameSuperClass(info: Analyzer#ClassInfo): Boolean = + if (info.superClass == null) superClass.isEmpty + else superClass.exists(_.encodedName == info.superClass.encodedName) + + getClassInfoIfNeeded(encodedName) match { + case Some(classInfo) if sameSuperClass(classInfo) => + // Class still exists. Recurse. + subclasses = subclasses.filter( + _.walkClassesForDeletions(getClassInfoIfNeeded)) + if (isInstantiated && !classInfo.isInstantiated) + notInstantiatedAnymore() + true + case _ => + // Class does not exist or has been moved. Delete the entire subtree. + deleteSubtree() + false + } + } + + /** Delete this class and all its subclasses. UPDATE PASS ONLY. */ + def deleteSubtree(): Unit = { + delete() + for (subclass <- subclasses) + subclass.deleteSubtree() + } + + /** UPDATE PASS ONLY. */ + private def delete(): Unit = { + if (isInstantiated) + notInstantiatedAnymore() + for (method <- methods.values) + method.delete() + classes -= encodedName + /* Note: no need to tag methods that call *statically* one of the methods + * of the deleted classes, since they've got to be invalidated by + * themselves. + */ + } + + /** UPDATE PASS ONLY. */ + def notInstantiatedAnymore(): Unit = { + assert(isInstantiated) + isInstantiated = false + for (intf <- interfaces) { + intf.removeInstantiatedSubclass(this) + for (methodName <- allMethods().keys) + intf.tagDynamicCallersOf(methodName) + } + } + + /** UPDATE PASS ONLY. */ + def walkForChanges( + getClassInfo: String => Analyzer#ClassInfo, + getClassTreeIfChanged: GetClassTreeIfChanged, + parentMethodAttributeChanges: Set[String]): Unit = { + + val classInfo = getClassInfo(encodedName) + + val (addedMethods, changedMethods, deletedMethods) = + updateWith(classInfo, getClassTreeIfChanged) + + val oldInterfaces = interfaces + val newInterfaces = + classInfo.ancestors.map(info => getInterface(info.encodedName)).toSet + interfaces = newInterfaces + + val methodAttributeChanges = + (parentMethodAttributeChanges -- methods.keys ++ + addedMethods ++ changedMethods ++ deletedMethods) + + // Tag callers with dynamic calls + val wasInstantiated = isInstantiated + isInstantiated = classInfo.isInstantiated + assert(!(wasInstantiated && !isInstantiated), + "(wasInstantiated && !isInstantiated) should have been handled "+ + "during deletion phase") + + if (isInstantiated) { + if (wasInstantiated) { + val existingInterfaces = oldInterfaces.intersect(newInterfaces) + for { + intf <- existingInterfaces + methodName <- methodAttributeChanges + } { + intf.tagDynamicCallersOf(methodName) + } + if (newInterfaces.size != oldInterfaces.size || + newInterfaces.size != existingInterfaces.size) { + val allMethodNames = allMethods().keys + for { + intf <- oldInterfaces ++ newInterfaces -- existingInterfaces + methodName <- allMethodNames + } { + intf.tagDynamicCallersOf(methodName) + } + } + } else { + val allMethodNames = allMethods().keys + for (intf <- interfaces) { + intf.addInstantiatedSubclass(this) + for (methodName <- allMethodNames) + intf.tagDynamicCallersOf(methodName) + } + } + } + + // Tag callers with static calls + for (methodName <- methodAttributeChanges) + myInterface.tagStaticCallersOf(methodName) + + // Module class specifics + updateHasElidableModuleAccessor() + + // Inlineable class + if (updateIsInlineable(classInfo)) { + for (method <- methods.values; if isConstructorName(method.encodedName)) + myInterface.tagStaticCallersOf(method.encodedName) + } + + // Recurse in subclasses + for (cls <- subclasses) + cls.walkForChanges(getClassInfo, getClassTreeIfChanged, + methodAttributeChanges) + } + + /** UPDATE PASS ONLY. */ + def walkForAdditions( + getNewChildren: String => GenIterable[Analyzer#ClassInfo], + getClassTreeIfChanged: GetClassTreeIfChanged): Unit = { + + val subclassAcc = CollOps.prepAdd(subclasses) + + for (classInfo <- getNewChildren(encodedName)) { + val cls = new Class(Some(this), classInfo.encodedName) + CollOps.add(subclassAcc, cls) + classes += classInfo.encodedName -> cls + cls.setupAfterCreation(classInfo, getClassTreeIfChanged) + cls.walkForAdditions(getNewChildren, getClassTreeIfChanged) + } + + subclasses = CollOps.finishAdd(subclassAcc) + } + + /** UPDATE PASS ONLY. */ + def updateHasElidableModuleAccessor(): Unit = { + hasElidableModuleAccessor = + isAdHocElidableModuleAccessor(encodedName) || + (isModuleClass && lookupMethod("init___").exists(isElidableModuleConstructor)) + } + + /** UPDATE PASS ONLY. */ + def updateIsInlineable(classInfo: Analyzer#ClassInfo): Boolean = { + val oldTryNewInlineable = tryNewInlineable + isInlineable = classInfo.optimizerHints.hasInlineAnnot + if (!isInlineable) { + tryNewInlineable = None + } else { + val allFields = reverseParentChain.flatMap(_.fields) + val (fieldValues, fieldTypes) = (for { + VarDef(Ident(name, originalName), tpe, mutable, rhs) <- allFields + } yield { + (rhs, RecordType.Field(name, originalName, tpe, mutable)) + }).unzip + tryNewInlineable = Some( + RecordValue(RecordType(fieldTypes), fieldValues)(Position.NoPosition)) + } + tryNewInlineable != oldTryNewInlineable + } + + /** UPDATE PASS ONLY. */ + def setupAfterCreation(classInfo: Analyzer#ClassInfo, + getClassTreeIfChanged: GetClassTreeIfChanged): Unit = { + + updateWith(classInfo, getClassTreeIfChanged) + interfaces = + classInfo.ancestors.map(info => getInterface(info.encodedName)).toSet + + isInstantiated = classInfo.isInstantiated + + if (batchMode) { + if (isInstantiated) { + /* Only add the class to all its ancestor interfaces */ + for (intf <- interfaces) + intf.addInstantiatedSubclass(this) + } + } else { + val allMethodNames = allMethods().keys + + if (isInstantiated) { + /* Add the class to all its ancestor interfaces + notify all callers + * of any of the methods. + * TODO: be more selective on methods that are notified: it is not + * necessary to modify callers of methods defined in a parent class + * that already existed in the previous run. + */ + for (intf <- interfaces) { + intf.addInstantiatedSubclass(this) + for (methodName <- allMethodNames) + intf.tagDynamicCallersOf(methodName) + } + } + + /* Tag static callers because the class could have been *moved*, + * not just added. + */ + for (methodName <- allMethodNames) + myInterface.tagStaticCallersOf(methodName) + } + + updateHasElidableModuleAccessor() + updateIsInlineable(classInfo) + } + + /** UPDATE PASS ONLY. */ + private def isElidableModuleConstructor(impl: MethodImpl): Boolean = { + def isTriviallySideEffectFree(tree: Tree): Boolean = tree match { + case _:VarRef | _:Literal | _:This => true + case _ => false + } + def isElidableStat(tree: Tree): Boolean = tree match { + case Block(stats) => + stats.forall(isElidableStat) + case Assign(Select(This(), _, _), rhs) => + isTriviallySideEffectFree(rhs) + case TraitImplApply(ClassType(traitImpl), methodName, List(This())) => + traitImpls(traitImpl).methods(methodName.name).originalDef.body match { + case Skip() => true + case _ => false + } + case StaticApply(This(), ClassType(cls), methodName, args) => + Definitions.isConstructorName(methodName.name) && + args.forall(isTriviallySideEffectFree) && + impl.owner.asInstanceOf[Class].superClass.exists { superCls => + superCls.encodedName == cls && + superCls.lookupMethod(methodName.name).exists(isElidableModuleConstructor) + } + case StoreModule(_, _) => + true + case _ => + isTriviallySideEffectFree(tree) + } + isElidableStat(impl.originalDef.body) + } + + /** All the methods of this class, including inherited ones. + * It has () so we remember this is an expensive operation. + * UPDATE PASS ONLY. + */ + def allMethods(): scala.collection.Map[String, MethodImpl] = { + val result = mutable.Map.empty[String, MethodImpl] + for (parent <- reverseParentChain) + result ++= parent.methods + result + } + + /** BOTH PASSES. */ + @tailrec + final def lookupMethod(methodName: String): Option[MethodImpl] = { + methods.get(methodName) match { + case Some(impl) => Some(impl) + case none => + superClass match { + case Some(p) => p.lookupMethod(methodName) + case none => None + } + } + } + } + + /** Trait impl. */ + class TraitImpl(_encodedName: String) extends MethodContainer(_encodedName) { + def thisType: Type = NoType + } + + /** Thing from which a [[MethodImpl]] can unregister itself from. */ + trait Unregisterable { + /** UPDATE PASS ONLY. */ + def unregisterDependee(dependee: MethodImpl): Unit + } + + /** Type of a class or interface. + * Types are created on demand when a method is called on a given + * [[ClassType]]. + * + * Fully concurrency safe unless otherwise noted. + */ + abstract class InterfaceType(val encodedName: String) extends Unregisterable { + + override def toString(): String = + s"intf $encodedName" + + /** PROCESS PASS ONLY. Concurrency safe except with + * [[addInstantiatedSubclass]] and [[removeInstantiatedSubclass]] + */ + def instantiatedSubclasses: Iterable[Class] + + /** UPDATE PASS ONLY. Concurrency safe except with + * [[instantiatedSubclasses]] + */ + def addInstantiatedSubclass(x: Class): Unit + + /** UPDATE PASS ONLY. Concurrency safe except with + * [[instantiatedSubclasses]] + */ + def removeInstantiatedSubclass(x: Class): Unit + + /** PROCESS PASS ONLY. Concurrency safe except with [[ancestors_=]] */ + def ancestors: List[String] + + /** UPDATE PASS ONLY. Not concurrency safe. */ + def ancestors_=(v: List[String]): Unit + + /** PROCESS PASS ONLY. Concurrency safe except with [[ancestors_=]]. */ + def registerAskAncestors(asker: MethodImpl): Unit + + /** PROCESS PASS ONLY. */ + def registerDynamicCaller(methodName: String, caller: MethodImpl): Unit + + /** PROCESS PASS ONLY. */ + def registerStaticCaller(methodName: String, caller: MethodImpl): Unit + + /** UPDATE PASS ONLY. */ + def tagDynamicCallersOf(methodName: String): Unit + + /** UPDATE PASS ONLY. */ + def tagStaticCallersOf(methodName: String): Unit + } + + /** A method implementation. + * It must be concrete, and belong either to a [[Class]] or a [[TraitImpl]]. + * + * A single instance is **not** concurrency safe (unless otherwise noted in + * a method comment). However, the global state modifications are + * concurrency safe. + */ + abstract class MethodImpl(val owner: MethodContainer, + val encodedName: String) extends OptimizerCore.MethodImpl + with OptimizerCore.AbstractMethodID + with Unregisterable { + private[this] var _deleted: Boolean = false + + var optimizerHints: OptimizerHints = OptimizerHints.empty + var originalDef: MethodDef = _ + var desugaredDef: JSTree = _ + var preciseInfo: Infos.MethodInfo = _ + + def thisType: Type = owner.thisType + def deleted: Boolean = _deleted + + override def toString(): String = + s"$owner.$encodedName" + + /** PROCESS PASS ONLY. */ + def registerBodyAsker(asker: MethodImpl): Unit + + /** UPDATE PASS ONLY. */ + def tagBodyAskers(): Unit + + /** PROCESS PASS ONLY. */ + private def registerAskAncestors(intf: InterfaceType): Unit = { + intf.registerAskAncestors(this) + registeredTo(intf) + } + + /** PROCESS PASS ONLY. */ + private def registerDynamicCall(intf: InterfaceType, + methodName: String): Unit = { + intf.registerDynamicCaller(methodName, this) + registeredTo(intf) + } + + /** PROCESS PASS ONLY. */ + private def registerStaticCall(intf: InterfaceType, + methodName: String): Unit = { + intf.registerStaticCaller(methodName, this) + registeredTo(intf) + } + + /** PROCESS PASS ONLY. */ + def registerAskBody(target: MethodImpl): Unit = { + target.registerBodyAsker(this) + registeredTo(target) + } + + /** PROCESS PASS ONLY. */ + protected def registeredTo(intf: Unregisterable): Unit + + /** UPDATE PASS ONLY. */ + protected def unregisterFromEverywhere(): Unit + + /** Return true iff this is the first time this method is called since the + * last reset (via [[resetTag]]). + * UPDATE PASS ONLY. + */ + protected def protectTag(): Boolean + + /** PROCESS PASS ONLY. */ + protected def resetTag(): Unit + + /** Returns true if the method's attributes changed. + * Attributes are whether it is inlineable, and whether it is a trait + * impl forwarder. Basically this is what is declared in + * [[OptimizerCore.AbstractMethodID]]. + * In the process, tags all the body askers if the body changes. + * UPDATE PASS ONLY. Not concurrency safe on same instance. + */ + def updateWith(methodInfo: Analyzer#MethodInfo, + methodDef: MethodDef): Boolean = { + assert(!_deleted, "updateWith() called on a deleted method") + + val bodyChanged = { + originalDef == null || + (methodDef.hash zip originalDef.hash).forall { + case (h1, h2) => !Hashers.hashesEqual(h1, h2, considerPositions) + } + } + + if (bodyChanged) + tagBodyAskers() + + val hints = methodInfo.optimizerHints + val changed = hints != optimizerHints || bodyChanged + if (changed) { + val oldAttributes = (inlineable, isTraitImplForwarder) + + optimizerHints = hints + originalDef = methodDef + desugaredDef = null + preciseInfo = null + updateInlineable() + tag() + + val newAttributes = (inlineable, isTraitImplForwarder) + newAttributes != oldAttributes + } else { + false + } + } + + /** UPDATE PASS ONLY. Not concurrency safe on same instance. */ + def delete(): Unit = { + assert(!_deleted, "delete() called twice") + _deleted = true + if (protectTag()) + unregisterFromEverywhere() + } + + /** Concurrency safe with itself and [[delete]] on the same instance + * + * [[tag]] can be called concurrently with [[delete]] when methods in + * traits/classes are updated. + * + * UPDATE PASS ONLY. + */ + def tag(): Unit = if (protectTag()) { + scheduleMethod(this) + unregisterFromEverywhere() + } + + /** PROCESS PASS ONLY. */ + def process(): Unit = if (!_deleted) { + val (optimizedDef, info) = new Optimizer().optimize(thisType, originalDef) + desugaredDef = + if (owner.isInstanceOf[Class]) + classEmitter.genMethod(owner.encodedName, optimizedDef) + else + classEmitter.genTraitImplMethod(owner.encodedName, optimizedDef) + preciseInfo = info + resetTag() + } + + /** All methods are PROCESS PASS ONLY */ + private class Optimizer extends OptimizerCore(semantics) { + type MethodID = MethodImpl + + val myself: MethodImpl.this.type = MethodImpl.this + + protected def getMethodBody(method: MethodID): MethodDef = { + MethodImpl.this.registerAskBody(method) + method.originalDef + } + + protected def dynamicCall(intfName: String, + methodName: String): List[MethodID] = { + val intf = getInterface(intfName) + MethodImpl.this.registerDynamicCall(intf, methodName) + intf.instantiatedSubclasses.flatMap(_.lookupMethod(methodName)).toList + } + + protected def staticCall(className: String, + methodName: String): Option[MethodID] = { + val clazz = classes(className) + MethodImpl.this.registerStaticCall(clazz.myInterface, methodName) + clazz.lookupMethod(methodName) + } + + protected def traitImplCall(traitImplName: String, + methodName: String): Option[MethodID] = { + val traitImpl = traitImpls(traitImplName) + registerStaticCall(traitImpl.myInterface, methodName) + traitImpl.methods.get(methodName) + } + + protected def getAncestorsOf(intfName: String): List[String] = { + val intf = getInterface(intfName) + registerAskAncestors(intf) + intf.ancestors + } + + protected def hasElidableModuleAccessor(moduleClassName: String): Boolean = + classes(moduleClassName).hasElidableModuleAccessor + + protected def tryNewInlineableClass(className: String): Option[RecordValue] = + classes(className).tryNewInlineable + } + } + +} + +object GenIncOptimizer { + + private val isAdHocElidableModuleAccessor = + Set("s_Predef$") + + private[optimizer] def logTime[A](logger: Logger, + title: String)(body: => A): A = { + val startTime = System.nanoTime() + val result = body + val endTime = System.nanoTime() + val elapsedTime = endTime - startTime + logger.time(title, elapsedTime) + result + } + + private[optimizer] trait AbsCollOps { + type Map[K, V] <: mutable.Map[K, V] + type ParMap[K, V] <: GenMap[K, V] + type AccMap[K, V] + type ParIterable[V] <: GenIterableLike[V, ParIterable[V]] + type Addable[V] + + def emptyAccMap[K, V]: AccMap[K, V] + def emptyMap[K, V]: Map[K, V] + def emptyParMap[K, V]: ParMap[K, V] + def emptyParIterable[V]: ParIterable[V] + + // Operations on ParMap + def put[K, V](map: ParMap[K, V], k: K, v: V): Unit + def remove[K, V](map: ParMap[K, V], k: K): Option[V] + def retain[K, V](map: ParMap[K, V])(p: (K, V) => Boolean): Unit + + // Operations on AccMap + def acc[K, V](map: AccMap[K, V], k: K, v: V): Unit + def getAcc[K, V](map: AccMap[K, V], k: K): GenIterable[V] + def parFlatMapKeys[A, B](map: AccMap[A, _])( + f: A => GenTraversableOnce[B]): GenIterable[B] + + // Operations on ParIterable + def prepAdd[V](it: ParIterable[V]): Addable[V] + def add[V](addable: Addable[V], v: V): Unit + def finishAdd[V](addable: Addable[V]): ParIterable[V] + + } + +} diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/IRChecker.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/IRChecker.scala new file mode 100644 index 0000000..6329826 --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/IRChecker.scala @@ -0,0 +1,854 @@ +/* __ *\ +** ________ ___ / / ___ __ ____ Scala.js tools ** +** / __/ __// _ | / / / _ | __ / // __/ (c) 2013-2014, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ |/_// /_\ \ http://scala-js.org/ ** +** /____/\___/_/ |_/____/_/ | |__/ /____/ ** +** |/____/ ** +\* */ + + +package scala.scalajs.tools.optimizer + +import scala.language.implicitConversions + +import scala.annotation.switch + +import scala.collection.mutable + +import scala.scalajs.ir._ +import Definitions._ +import Trees._ +import Types._ + +import scala.scalajs.tools.logging._ + +/** Checker for the validity of the IR. */ +class IRChecker(analyzer: Analyzer, allClassDefs: Seq[ClassDef], logger: Logger) { + import IRChecker._ + + private var _errorCount: Int = 0 + def errorCount: Int = _errorCount + + private val classes: mutable.Map[String, CheckedClass] = { + mutable.Map.empty[String, CheckedClass] ++= + allClassDefs.map(new CheckedClass(_)).map(c => c.name -> c) + } + + def check(): Boolean = { + for { + classDef <- allClassDefs + if analyzer.classInfos(classDef.name.name).isNeededAtAll + } { + classDef.kind match { + case ClassKind.Class | ClassKind.ModuleClass => checkClass(classDef) + case ClassKind.TraitImpl => checkTraitImpl(classDef) + case _ => + } + } + errorCount == 0 + } + + def checkClass(classDef: ClassDef): Unit = { + if (!analyzer.classInfos(classDef.name.name).isAnySubclassInstantiated) + return + + for (member <- classDef.defs) { + implicit val ctx = ErrorContext(member) + member match { + // Scala declarations + case v @ VarDef(_, _, _, _) => + checkFieldDef(v, classDef) + case m: MethodDef if m.name.isInstanceOf[Ident] => + checkMethodDef(m, classDef) + + // Exports + case m: MethodDef if m.name.isInstanceOf[StringLiteral] => + checkExportedMethodDef(m, classDef) + case member @ PropertyDef(_: StringLiteral, _, _, _) => + checkExportedPropertyDef(member, classDef) + case member @ ConstructorExportDef(_, _, _) => + checkConstructorExportDef(member, classDef) + case member @ ModuleExportDef(_) => + checkModuleExportDef(member, classDef) + + // Anything else is illegal + case _ => + reportError(s"Illegal class member of type ${member.getClass.getName}") + } + } + } + + def checkTraitImpl(classDef: ClassDef): Unit = { + for (member <- classDef.defs) { + implicit val ctx = ErrorContext(member) + member match { + case m: MethodDef => + checkMethodDef(m, classDef) + case _ => + reportError(s"Invalid member for a TraitImpl") + } + } + } + + def checkFieldDef(fieldDef: VarDef, classDef: ClassDef): Unit = { + val VarDef(name, tpe, mutable, rhs) = fieldDef + implicit val ctx = ErrorContext(fieldDef) + + if (tpe == NoType) + reportError(s"VarDef cannot have type NoType") + else + typecheckExpect(rhs, Env.empty, tpe) + } + + def checkMethodDef(methodDef: MethodDef, classDef: ClassDef): Unit = { + val MethodDef(Ident(name, _), params, resultType, body) = methodDef + implicit val ctx = ErrorContext(methodDef) + + if (!analyzer.classInfos(classDef.name.name).methodInfos(name).isReachable) + return + + for (ParamDef(name, tpe, _) <- params) + if (tpe == NoType) + reportError(s"Parameter $name has type NoType") + + val resultTypeForSig = + if (isConstructorName(name)) NoType + else resultType + + val advertizedSig = (params.map(_.ptpe), resultTypeForSig) + val sigFromName = inferMethodType(name, + inTraitImpl = classDef.kind == ClassKind.TraitImpl) + if (advertizedSig != sigFromName) { + reportError( + s"The signature of ${classDef.name.name}.$name, which is "+ + s"$advertizedSig, does not match its name (should be $sigFromName).") + } + + val thisType = + if (!classDef.kind.isClass) NoType + else ClassType(classDef.name.name) + val bodyEnv = Env.fromSignature(thisType, params, resultType) + if (resultType == NoType) + typecheckStat(body, bodyEnv) + else + typecheckExpect(body, bodyEnv, resultType) + } + + def checkExportedMethodDef(methodDef: MethodDef, classDef: ClassDef): Unit = { + val MethodDef(_, params, resultType, body) = methodDef + implicit val ctx = ErrorContext(methodDef) + + if (!classDef.kind.isClass) { + reportError(s"Exported method def can only appear in a class") + return + } + + for (ParamDef(name, tpe, _) <- params) { + if (tpe == NoType) + reportError(s"Parameter $name has type NoType") + else if (tpe != AnyType) + reportError(s"Parameter $name of exported method def has type $tpe, "+ + "but must be Any") + } + + if (resultType != AnyType) { + reportError(s"Result type of exported method def is $resultType, "+ + "but must be Any") + } + + val thisType = ClassType(classDef.name.name) + val bodyEnv = Env.fromSignature(thisType, params, resultType) + .withArgumentsVar(methodDef.pos) + typecheckExpect(body, bodyEnv, resultType) + } + + def checkExportedPropertyDef(propDef: PropertyDef, classDef: ClassDef): Unit = { + val PropertyDef(_, getterBody, setterArg, setterBody) = propDef + implicit val ctx = ErrorContext(propDef) + + if (!classDef.kind.isClass) { + reportError(s"Exported property def can only appear in a class") + return + } + + val thisType = ClassType(classDef.name.name) + + if (getterBody != EmptyTree) { + val getterBodyEnv = Env.fromSignature(thisType, Nil, AnyType) + typecheckExpect(getterBody, getterBodyEnv, AnyType) + } + + if (setterBody != EmptyTree) { + if (setterArg.ptpe != AnyType) + reportError("Setter argument of exported property def has type "+ + s"${setterArg.ptpe}, but must be Any") + + val setterBodyEnv = Env.fromSignature(thisType, List(setterArg), NoType) + typecheckStat(setterBody, setterBodyEnv) + } + } + + def checkConstructorExportDef(ctorDef: ConstructorExportDef, + classDef: ClassDef): Unit = { + val ConstructorExportDef(_, params, body) = ctorDef + implicit val ctx = ErrorContext(ctorDef) + + if (!classDef.kind.isClass) { + reportError(s"Exported constructor def can only appear in a class") + return + } + + for (ParamDef(name, tpe, _) <- params) { + if (tpe == NoType) + reportError(s"Parameter $name has type NoType") + else if (tpe != AnyType) + reportError(s"Parameter $name of exported constructor def has type "+ + s"$tpe, but must be Any") + } + + val thisType = ClassType(classDef.name.name) + val bodyEnv = Env.fromSignature(thisType, params, NoType) + .withArgumentsVar(ctorDef.pos) + typecheckStat(body, bodyEnv) + } + + def checkModuleExportDef(moduleDef: ModuleExportDef, + classDef: ClassDef): Unit = { + implicit val ctx = ErrorContext(moduleDef) + + if (classDef.kind != ClassKind.ModuleClass) + reportError(s"Exported module def can only appear in a module class") + } + + def typecheckStat(tree: Tree, env: Env): Env = { + implicit val ctx = ErrorContext(tree) + + tree match { + case VarDef(ident, vtpe, mutable, rhs) => + typecheckExpect(rhs, env, vtpe) + env.withLocal(LocalDef(ident.name, vtpe, mutable)(tree.pos)) + + case Skip() => + env + + case Assign(select, rhs) => + select match { + case Select(_, Ident(name, _), false) => + /* TODO In theory this case would verify that we never assign to + * an immutable field. But we cannot do that because we *do* emit + * such assigns in constructors. + * In the future we might want to check that only these legal + * special cases happen, and nothing else. But it seems non-trivial + * to do so, so currently we trust scalac not to make us emit + * illegal assigns. + */ + //reportError(s"Assignment to immutable field $name.") + case VarRef(Ident(name, _), false) => + reportError(s"Assignment to immutable variable $name.") + case _ => + } + val lhsTpe = typecheckExpr(select, env) + val expectedRhsTpe = select match { + case _:JSDotSelect | _:JSBracketSelect => AnyType + case _ => lhsTpe + } + typecheckExpect(rhs, env, expectedRhsTpe) + env + + case StoreModule(cls, value) => + if (!cls.className.endsWith("$")) + reportError("StoreModule of non-module class $cls") + typecheckExpect(value, env, ClassType(cls.className)) + env + + case Block(stats) => + (env /: stats) { (prevEnv, stat) => + typecheckStat(stat, prevEnv) + } + env + + case Labeled(label, NoType, body) => + typecheckStat(body, env.withLabeledReturnType(label.name, AnyType)) + env + + case If(cond, thenp, elsep) => + typecheckExpect(cond, env, BooleanType) + typecheckStat(thenp, env) + typecheckStat(elsep, env) + env + + case While(cond, body, label) => + typecheckExpect(cond, env, BooleanType) + typecheckStat(body, env) + env + + case DoWhile(body, cond, label) => + typecheckStat(body, env) + typecheckExpect(cond, env, BooleanType) + env + + case Try(block, errVar, handler, finalizer) => + typecheckStat(block, env) + if (handler != EmptyTree) { + val handlerEnv = + env.withLocal(LocalDef(errVar.name, AnyType, false)(errVar.pos)) + typecheckStat(handler, handlerEnv) + } + if (finalizer != EmptyTree) { + typecheckStat(finalizer, env) + } + env + + case Match(selector, cases, default) => + typecheckExpr(selector, env) + for ((alts, body) <- cases) { + alts.foreach(typecheckExpr(_, env)) + typecheckStat(body, env) + } + typecheckStat(default, env) + env + + case Debugger() => + env + + case JSDelete(JSDotSelect(obj, prop)) => + typecheckExpr(obj, env) + env + + case JSDelete(JSBracketSelect(obj, prop)) => + typecheckExpr(obj, env) + typecheckExpr(prop, env) + env + + case _ => + typecheck(tree, env) + env + } + } + + def typecheckExpect(tree: Tree, env: Env, expectedType: Type)( + implicit ctx: ErrorContext): Unit = { + val tpe = typecheckExpr(tree, env) + if (!isSubtype(tpe, expectedType)) + reportError(s"$expectedType expected but $tpe found "+ + s"for tree of type ${tree.getClass.getName}") + } + + def typecheckExpr(tree: Tree, env: Env): Type = { + implicit val ctx = ErrorContext(tree) + if (tree.tpe == NoType) + reportError(s"Expression tree has type NoType") + typecheck(tree, env) + } + + def typecheck(tree: Tree, env: Env): Type = { + implicit val ctx = ErrorContext(tree) + + def checkApplyGeneric(methodName: String, methodFullName: String, + args: List[Tree], inTraitImpl: Boolean): Unit = { + val (methodParams, resultType) = inferMethodType(methodName, inTraitImpl) + if (args.size != methodParams.size) + reportError(s"Arity mismatch: ${methodParams.size} expected but "+ + s"${args.size} found") + for ((actual, formal) <- args zip methodParams) { + typecheckExpect(actual, env, formal) + } + if (!isConstructorName(methodName) && tree.tpe != resultType) + reportError(s"Call to $methodFullName of type $resultType "+ + s"typed as ${tree.tpe}") + } + + tree match { + // Control flow constructs + + case Block(statsAndExpr) => + val stats :+ expr = statsAndExpr + val envAfterStats = (env /: stats) { (prevEnv, stat) => + typecheckStat(stat, prevEnv) + } + typecheckExpr(expr, envAfterStats) + + case Labeled(label, tpe, body) => + typecheckExpect(body, env.withLabeledReturnType(label.name, tpe), tpe) + + case Return(expr, label) => + env.returnTypes.get(label.map(_.name)).fold[Unit] { + reportError(s"Cannot return to label $label.") + typecheckExpr(expr, env) + } { returnType => + typecheckExpect(expr, env, returnType) + } + + case If(cond, thenp, elsep) => + val tpe = tree.tpe + typecheckExpect(cond, env, BooleanType) + typecheckExpect(thenp, env, tpe) + typecheckExpect(elsep, env, tpe) + + case While(BooleanLiteral(true), body, label) if tree.tpe == NothingType => + typecheckStat(body, env) + + case Try(block, errVar, handler, finalizer) => + val tpe = tree.tpe + typecheckExpect(block, env, tpe) + if (handler != EmptyTree) { + val handlerEnv = + env.withLocal(LocalDef(errVar.name, AnyType, false)(errVar.pos)) + typecheckExpect(handler, handlerEnv, tpe) + } + if (finalizer != EmptyTree) { + typecheckStat(finalizer, env) + } + + case Throw(expr) => + typecheckExpr(expr, env) + + case Continue(label) => + /* Here we could check that it is indeed legal to break to the + * specified label. However, if we do anything illegal here, it will + * result in a SyntaxError in JavaScript anyway, so we do not really + * care. + */ + + case Match(selector, cases, default) => + val tpe = tree.tpe + typecheckExpr(selector, env) + for ((alts, body) <- cases) { + alts.foreach(typecheckExpr(_, env)) + typecheckExpect(body, env, tpe) + } + typecheckExpect(default, env, tpe) + + // Scala expressions + + case New(cls, ctor, args) => + val clazz = lookupClass(cls) + if (!clazz.kind.isClass) + reportError(s"new $cls which is not a class") + checkApplyGeneric(ctor.name, s"$cls.$ctor", args, + inTraitImpl = false) + + case LoadModule(cls) => + if (!cls.className.endsWith("$")) + reportError("LoadModule of non-module class $cls") + + case Select(qualifier, Ident(item, _), mutable) => + val qualType = typecheckExpr(qualifier, env) + qualType match { + case ClassType(cls) => + val clazz = lookupClass(cls) + if (!clazz.kind.isClass) { + reportError(s"Cannot select $item of non-class $cls") + } else { + clazz.lookupField(item).fold[Unit] { + reportError(s"Class $cls does not have a field $item") + } { fieldDef => + if (fieldDef.tpe != tree.tpe) + reportError(s"Select $cls.$item of type "+ + s"${fieldDef.tpe} typed as ${tree.tpe}") + if (fieldDef.mutable != mutable) + reportError(s"Select $cls.$item with "+ + s"mutable=${fieldDef.mutable} marked as mutable=$mutable") + } + } + case NullType | NothingType => + // always ok + case _ => + reportError(s"Cannot select $item of non-class type $qualType") + } + + case Apply(receiver, Ident(method, _), args) => + val receiverType = typecheckExpr(receiver, env) + checkApplyGeneric(method, s"$receiverType.$method", args, + inTraitImpl = false) + + case StaticApply(receiver, cls, Ident(method, _), args) => + typecheckExpect(receiver, env, cls) + checkApplyGeneric(method, s"$cls.$method", args, inTraitImpl = false) + + case TraitImplApply(impl, Ident(method, _), args) => + val clazz = lookupClass(impl) + if (clazz.kind != ClassKind.TraitImpl) + reportError(s"Cannot trait-impl apply method of non-trait-impl $impl") + checkApplyGeneric(method, s"$impl.$method", args, inTraitImpl = true) + + case UnaryOp(op, lhs) => + import UnaryOp._ + (op: @switch) match { + case `typeof` => + typecheckExpr(lhs, env) + case IntToLong => + typecheckExpect(lhs, env, IntType) + case LongToInt | LongToDouble => + typecheckExpect(lhs, env, LongType) + case DoubleToInt | DoubleToFloat | DoubleToLong => + typecheckExpect(lhs, env, DoubleType) + case Boolean_! => + typecheckExpect(lhs, env, BooleanType) + } + + case BinaryOp(op, lhs, rhs) => + import BinaryOp._ + (op: @switch) match { + case === | !== | String_+ => + typecheckExpr(lhs, env) + typecheckExpr(rhs, env) + case `in` => + typecheckExpect(lhs, env, ClassType(StringClass)) + typecheckExpr(rhs, env) + case `instanceof` => + typecheckExpr(lhs, env) + typecheckExpr(rhs, env) + case Int_+ | Int_- | Int_* | Int_/ | Int_% | + Int_| | Int_& | Int_^ | Int_<< | Int_>>> | Int_>> => + typecheckExpect(lhs, env, IntType) + typecheckExpect(rhs, env, IntType) + case Float_+ | Float_- | Float_* | Float_/ | Float_% => + typecheckExpect(lhs, env, FloatType) + typecheckExpect(lhs, env, FloatType) + case Long_+ | Long_- | Long_* | Long_/ | Long_% | + Long_| | Long_& | Long_^ | + Long_== | Long_!= | Long_< | Long_<= | Long_> | Long_>= => + typecheckExpect(lhs, env, LongType) + typecheckExpect(rhs, env, LongType) + case Long_<< | Long_>>> | Long_>> => + typecheckExpect(lhs, env, LongType) + typecheckExpect(rhs, env, IntType) + case Double_+ | Double_- | Double_* | Double_/ | Double_% | + Num_== | Num_!= | Num_< | Num_<= | Num_> | Num_>= => + typecheckExpect(lhs, env, DoubleType) + typecheckExpect(lhs, env, DoubleType) + case Boolean_== | Boolean_!= | Boolean_| | Boolean_& => + typecheckExpect(lhs, env, BooleanType) + typecheckExpect(rhs, env, BooleanType) + } + + case NewArray(tpe, lengths) => + for (length <- lengths) + typecheckExpect(length, env, IntType) + + case ArrayValue(tpe, elems) => + val elemType = arrayElemType(tpe) + for (elem <- elems) + typecheckExpect(elem, env, elemType) + + case ArrayLength(array) => + val arrayType = typecheckExpr(array, env) + if (!arrayType.isInstanceOf[ArrayType]) + reportError(s"Array type expected but $arrayType found") + + case ArraySelect(array, index) => + typecheckExpect(index, env, IntType) + typecheckExpr(array, env) match { + case arrayType: ArrayType => + if (tree.tpe != arrayElemType(arrayType)) + reportError(s"Array select of array type $arrayType typed as ${tree.tpe}") + case arrayType => + reportError(s"Array type expected but $arrayType found") + } + + case IsInstanceOf(expr, cls) => + typecheckExpr(expr, env) + + case AsInstanceOf(expr, cls) => + typecheckExpr(expr, env) + + case Unbox(expr, _) => + typecheckExpr(expr, env) + + case GetClass(expr) => + typecheckExpr(expr, env) + + // JavaScript expressions + + case JSNew(ctor, args) => + typecheckExpr(ctor, env) + for (arg <- args) + typecheckExpr(arg, env) + + case JSDotSelect(qualifier, item) => + typecheckExpr(qualifier, env) + + case JSBracketSelect(qualifier, item) => + typecheckExpr(qualifier, env) + typecheckExpr(item, env) + + case JSFunctionApply(fun, args) => + typecheckExpr(fun, env) + for (arg <- args) + typecheckExpr(arg, env) + + case JSDotMethodApply(receiver, method, args) => + typecheckExpr(receiver, env) + for (arg <- args) + typecheckExpr(arg, env) + + case JSBracketMethodApply(receiver, method, args) => + typecheckExpr(receiver, env) + typecheckExpr(method, env) + for (arg <- args) + typecheckExpr(arg, env) + + case JSUnaryOp(op, lhs) => + typecheckExpr(lhs, env) + + case JSBinaryOp(op, lhs, rhs) => + typecheckExpr(lhs, env) + typecheckExpr(rhs, env) + + case JSArrayConstr(items) => + for (item <- items) + typecheckExpr(item, env) + + case JSObjectConstr(fields) => + for ((_, value) <- fields) + typecheckExpr(value, env) + + case JSEnvInfo() => + + // Literals + + case _: Literal => + + // Atomic expressions + + case VarRef(Ident(name, _), mutable) => + env.locals.get(name).fold[Unit] { + reportError(s"Cannot find variable $name in scope") + } { localDef => + if (tree.tpe != localDef.tpe) + reportError(s"Variable $name of type ${localDef.tpe} "+ + s"typed as ${tree.tpe}") + if (mutable != localDef.mutable) + reportError(s"Variable $name with mutable=${localDef.mutable} "+ + s"marked as mutable=$mutable") + } + + case This() => + if (!isSubtype(env.thisTpe, tree.tpe)) + reportError(s"this of type ${env.thisTpe} typed as ${tree.tpe}") + + case Closure(captureParams, params, body, captureValues) => + if (captureParams.size != captureValues.size) + reportError("Mismatched size for captures: "+ + s"${captureParams.size} params vs ${captureValues.size} values") + + for ((ParamDef(name, ctpe, mutable), value) <- captureParams zip captureValues) { + if (mutable) + reportError(s"Capture parameter $name cannot be mutable") + if (ctpe == NoType) + reportError(s"Parameter $name has type NoType") + else + typecheckExpect(value, env, ctpe) + } + + for (ParamDef(name, ptpe, mutable) <- params) { + if (ptpe == NoType) + reportError(s"Parameter $name has type NoType") + else if (ptpe != AnyType) + reportError(s"Closure parameter $name has type $ptpe instead of any") + } + + val bodyEnv = Env.fromSignature( + AnyType, captureParams ++ params, AnyType) + typecheckExpect(body, bodyEnv, AnyType) + + case _ => + reportError(s"Invalid expression tree") + } + + tree.tpe + } + + def inferMethodType(encodedName: String, inTraitImpl: Boolean)( + implicit ctx: ErrorContext): (List[Type], Type) = { + def dropPrivateMarker(params: List[String]): List[String] = + if (params.nonEmpty && params.head.startsWith("p")) params.tail + else params + + if (isConstructorName(encodedName)) { + assert(!inTraitImpl, "Trait impl should not have a constructor") + val params = dropPrivateMarker( + encodedName.stripPrefix("init___").split("__").toList) + if (params == List("")) (Nil, NoType) + else (params.map(decodeType), NoType) + } else if (isReflProxyName(encodedName)) { + assert(!inTraitImpl, "Trait impl should not have refl proxy methods") + val params = dropPrivateMarker(encodedName.split("__").toList.tail) + (params.map(decodeType), AnyType) + } else { + val paramsAndResult0 = + encodedName.split("__").toList.tail + val paramsAndResult1 = + if (inTraitImpl) paramsAndResult0.tail + else paramsAndResult0 + val paramsAndResult = + dropPrivateMarker(paramsAndResult1) + (paramsAndResult.init.map(decodeType), decodeType(paramsAndResult.last)) + } + } + + def decodeType(encodedName: String)(implicit ctx: ErrorContext): Type = { + if (encodedName.isEmpty) NoType + else if (encodedName.charAt(0) == 'A') { + // array type + val dims = encodedName.indexWhere(_ != 'A') + val base = encodedName.substring(dims) + ArrayType(base, dims) + } else if (encodedName.length == 1) { + (encodedName.charAt(0): @switch) match { + case 'V' => NoType + case 'Z' => BooleanType + case 'C' | 'B' | 'S' | 'I' => IntType + case 'J' => LongType + case 'F' => FloatType + case 'D' => DoubleType + case 'O' => AnyType + case 'T' => ClassType(StringClass) // NOT StringType + } + } else if (encodedName == "sr_Nothing$") { + NothingType + } else if (encodedName == "sr_Null$") { + NullType + } else { + val clazz = lookupClass(encodedName) + if (clazz.kind == ClassKind.RawJSType) AnyType + else ClassType(encodedName) + } + } + + def arrayElemType(arrayType: ArrayType)(implicit ctx: ErrorContext): Type = { + if (arrayType.dimensions == 1) decodeType(arrayType.baseClassName) + else ArrayType(arrayType.baseClassName, arrayType.dimensions-1) + } + + def reportError(msg: String)(implicit ctx: ErrorContext): Unit = { + logger.error(s"$ctx: $msg") + _errorCount += 1 + } + + def lookupClass(className: String)(implicit ctx: ErrorContext): CheckedClass = { + classes.getOrElseUpdate(className, { + reportError(s"Cannot find class $className") + new CheckedClass(className, ClassKind.Class, + Some(ObjectClass), Set(ObjectClass)) + }) + } + + def lookupClass(classType: ClassType)(implicit ctx: ErrorContext): CheckedClass = + lookupClass(classType.className) + + def isSubclass(lhs: String, rhs: String)(implicit ctx: ErrorContext): Boolean = { + lookupClass(lhs).isSubclass(lookupClass(rhs)) + } + + def isSubtype(lhs: Type, rhs: Type)(implicit ctx: ErrorContext): Boolean = { + Types.isSubtype(lhs, rhs)(isSubclass) + } + + class Env( + /** Type of `this`. Can be NoType. */ + val thisTpe: Type, + /** Local variables in scope (including through closures). */ + val locals: Map[String, LocalDef], + /** Return types by label. */ + val returnTypes: Map[Option[String], Type] + ) { + def withThis(thisTpe: Type): Env = + new Env(thisTpe, this.locals, this.returnTypes) + + def withLocal(localDef: LocalDef): Env = + new Env(thisTpe, locals + (localDef.name -> localDef), returnTypes) + + def withLocals(localDefs: TraversableOnce[LocalDef]): Env = + new Env(thisTpe, locals ++ localDefs.map(d => d.name -> d), returnTypes) + + def withReturnType(returnType: Type): Env = + new Env(this.thisTpe, this.locals, returnTypes + (None -> returnType)) + + def withLabeledReturnType(label: String, returnType: Type): Env = + new Env(this.thisTpe, this.locals, returnTypes + (Some(label) -> returnType)) + + def withArgumentsVar(pos: Position): Env = + withLocal(LocalDef("arguments", AnyType, mutable = false)(pos)) + } + + object Env { + val empty: Env = new Env(NoType, Map.empty, Map.empty) + + def fromSignature(thisType: Type, params: List[ParamDef], + resultType: Type): Env = { + val paramLocalDefs = + for (p @ ParamDef(name, tpe, mutable) <- params) yield + name.name -> LocalDef(name.name, tpe, mutable)(p.pos) + new Env(thisType, paramLocalDefs.toMap, + Map(None -> (if (resultType == NoType) AnyType else resultType))) + } + } + + class CheckedClass( + val name: String, + val kind: ClassKind, + val superClassName: Option[String], + val ancestors: Set[String], + _fields: TraversableOnce[CheckedField] = Nil) { + + val fields = _fields.map(f => f.name -> f).toMap + + lazy val superClass = superClassName.map(classes) + + def this(classDef: ClassDef) = { + this(classDef.name.name, classDef.kind, + classDef.parent.map(_.name), + classDef.ancestors.map(_.name).toSet, + CheckedClass.collectFields(classDef)) + } + + def isSubclass(that: CheckedClass): Boolean = + this == that || ancestors.contains(that.name) + + def isAncestorOfHijackedClass: Boolean = + AncestorsOfHijackedClasses.contains(name) + + def lookupField(name: String): Option[CheckedField] = + fields.get(name).orElse(superClass.flatMap(_.lookupField(name))) + } + + object CheckedClass { + private def collectFields(classDef: ClassDef) = { + classDef.defs collect { + case VarDef(Ident(name, _), tpe, mutable, _) => + new CheckedField(name, tpe, mutable) + } + } + } + + class CheckedField(val name: String, val tpe: Type, val mutable: Boolean) +} + +object IRChecker { + private final class ErrorContext(val tree: Tree) extends AnyVal { + override def toString(): String = { + val pos = tree.pos + s"${pos.source}(${pos.line+1}:${pos.column+1}:${tree.getClass.getSimpleName})" + } + + def pos: Position = tree.pos + } + + private object ErrorContext { + implicit def tree2errorContext(tree: Tree): ErrorContext = + ErrorContext(tree) + + def apply(tree: Tree): ErrorContext = + new ErrorContext(tree) + } + + private def isConstructorName(name: String): Boolean = + name.startsWith("init___") + + private def isReflProxyName(name: String): Boolean = + name.endsWith("__") && !isConstructorName(name) + + case class LocalDef(name: String, tpe: Type, mutable: Boolean)(val pos: Position) +} diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/IncOptimizer.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/IncOptimizer.scala new file mode 100644 index 0000000..d115618 --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/IncOptimizer.scala @@ -0,0 +1,158 @@ +/* __ *\ +** ________ ___ / / ___ __ ____ Scala.js tools ** +** / __/ __// _ | / / / _ | __ / // __/ (c) 2013-2014, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ |/_// /_\ \ http://scala-js.org/ ** +** /____/\___/_/ |_/____/_/ | |__/ /____/ ** +** |/____/ ** +\* */ + + +package scala.scalajs.tools.optimizer + +import scala.collection.{GenTraversableOnce, GenIterable} +import scala.collection.mutable + +import scala.scalajs.tools.sem.Semantics + +class IncOptimizer(semantics: Semantics) extends GenIncOptimizer(semantics) { + + protected object CollOps extends GenIncOptimizer.AbsCollOps { + type Map[K, V] = mutable.Map[K, V] + type ParMap[K, V] = mutable.Map[K, V] + type AccMap[K, V] = mutable.Map[K, mutable.ListBuffer[V]] + type ParIterable[V] = mutable.ListBuffer[V] + type Addable[V] = mutable.ListBuffer[V] + + def emptyAccMap[K, V]: AccMap[K, V] = mutable.Map.empty + def emptyMap[K, V]: Map[K, V] = mutable.Map.empty + def emptyParMap[K, V]: ParMap[K, V] = mutable.Map.empty + def emptyParIterable[V]: ParIterable[V] = mutable.ListBuffer.empty + + // Operations on ParMap + def put[K, V](map: ParMap[K, V], k: K, v: V): Unit = map.put(k, v) + def remove[K, V](map: ParMap[K, V], k: K): Option[V] = map.remove(k) + + def retain[K, V](map: ParMap[K, V])(p: (K, V) => Boolean): Unit = + map.retain(p) + + // Operations on AccMap + def acc[K, V](map: AccMap[K, V], k: K, v: V): Unit = + map.getOrElseUpdate(k, mutable.ListBuffer.empty) += v + + def getAcc[K, V](map: AccMap[K, V], k: K): GenIterable[V] = + map.getOrElse(k, Nil) + + def parFlatMapKeys[A, B](map: AccMap[A, _])( + f: A => GenTraversableOnce[B]): GenIterable[B] = + map.keys.flatMap(f).toList + + // Operations on ParIterable + def prepAdd[V](it: ParIterable[V]): Addable[V] = it + def add[V](addable: Addable[V], v: V): Unit = addable += v + def finishAdd[V](addable: Addable[V]): ParIterable[V] = addable + } + + private val _interfaces = mutable.Map.empty[String, InterfaceType] + protected def getInterface(encodedName: String): InterfaceType = + _interfaces.getOrElseUpdate(encodedName, new SeqInterfaceType(encodedName)) + + private val methodsToProcess = mutable.ListBuffer.empty[MethodImpl] + protected def scheduleMethod(method: MethodImpl): Unit = + methodsToProcess += method + + protected def newMethodImpl(owner: MethodContainer, + encodedName: String): MethodImpl = new SeqMethodImpl(owner, encodedName) + + protected def processAllTaggedMethods(): Unit = { + logProcessingMethods(methodsToProcess.count(!_.deleted)) + for (method <- methodsToProcess) + method.process() + methodsToProcess.clear() + } + + private class SeqInterfaceType(encName: String) extends InterfaceType(encName) { + private val ancestorsAskers = mutable.Set.empty[MethodImpl] + private val dynamicCallers = mutable.Map.empty[String, mutable.Set[MethodImpl]] + private val staticCallers = mutable.Map.empty[String, mutable.Set[MethodImpl]] + + private var _ancestors: List[String] = encodedName :: Nil + + private var _instantiatedSubclasses: Set[Class] = Set.empty + + def instantiatedSubclasses: Iterable[Class] = _instantiatedSubclasses + + def addInstantiatedSubclass(x: Class): Unit = + _instantiatedSubclasses += x + + def removeInstantiatedSubclass(x: Class): Unit = + _instantiatedSubclasses -= x + + def ancestors: List[String] = _ancestors + + def ancestors_=(v: List[String]): Unit = { + if (v != _ancestors) { + _ancestors = v + ancestorsAskers.foreach(_.tag()) + ancestorsAskers.clear() + } + } + + def registerAskAncestors(asker: MethodImpl): Unit = + ancestorsAskers += asker + + def registerDynamicCaller(methodName: String, caller: MethodImpl): Unit = + dynamicCallers.getOrElseUpdate(methodName, mutable.Set.empty) += caller + + def registerStaticCaller(methodName: String, caller: MethodImpl): Unit = + staticCallers.getOrElseUpdate(methodName, mutable.Set.empty) += caller + + def unregisterDependee(dependee: MethodImpl): Unit = { + ancestorsAskers -= dependee + dynamicCallers.values.foreach(_ -= dependee) + staticCallers.values.foreach(_ -= dependee) + } + + def tagDynamicCallersOf(methodName: String): Unit = + dynamicCallers.remove(methodName).foreach(_.foreach(_.tag())) + + def tagStaticCallersOf(methodName: String): Unit = + staticCallers.remove(methodName).foreach(_.foreach(_.tag())) + } + + private class SeqMethodImpl(owner: MethodContainer, + encodedName: String) extends MethodImpl(owner, encodedName) { + + private val bodyAskers = mutable.Set.empty[MethodImpl] + + def registerBodyAsker(asker: MethodImpl): Unit = + bodyAskers += asker + + def unregisterDependee(dependee: MethodImpl): Unit = + bodyAskers -= dependee + + def tagBodyAskers(): Unit = { + bodyAskers.foreach(_.tag()) + bodyAskers.clear() + } + + private var _registeredTo: List[Unregisterable] = Nil + private var tagged = false + + protected def registeredTo(intf: Unregisterable): Unit = + _registeredTo ::= intf + + protected def unregisterFromEverywhere(): Unit = { + _registeredTo.foreach(_.unregisterDependee(this)) + _registeredTo = Nil + } + + protected def protectTag(): Boolean = { + val res = !tagged + tagged = true + res + } + protected def resetTag(): Unit = tagged = false + + } + +} diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/JSTreeBuilder.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/JSTreeBuilder.scala new file mode 100644 index 0000000..3d37a56 --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/JSTreeBuilder.scala @@ -0,0 +1,16 @@ +package scala.scalajs.tools.optimizer + +import scala.scalajs.ir +import scala.scalajs.tools.javascript + +/** An abstract builder taking IR or JSTrees */ +trait JSTreeBuilder { + /** Add a JavaScript tree representing a statement. + * The tree must be a valid JavaScript tree (typically obtained by + * desugaring a full-fledged IR tree). + */ + def addJSTree(tree: javascript.Trees.Tree): Unit + + /** Completes the builder. */ + def complete(): Unit = () +} diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/OptimizerCore.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/OptimizerCore.scala new file mode 100644 index 0000000..364038b --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/OptimizerCore.scala @@ -0,0 +1,3572 @@ +/* __ *\ +** ________ ___ / / ___ __ ____ Scala.js tools ** +** / __/ __// _ | / / / _ | __ / // __/ (c) 2013-2014, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ |/_// /_\ \ http://scala-js.org/ ** +** /____/\___/_/ |_/____/_/ | |__/ /____/ ** +** |/____/ ** +\* */ + + +package scala.scalajs.tools.optimizer + +import scala.language.implicitConversions + +import scala.annotation.{switch, tailrec} + +import scala.collection.mutable + +import scala.util.control.{NonFatal, ControlThrowable, TailCalls} +import scala.util.control.TailCalls.{done => _, _} // done is a too generic term + +import scala.scalajs.ir._ +import Definitions.{ObjectClass, isConstructorName, isReflProxyName} +import Infos.OptimizerHints +import Trees._ +import Types._ + +import scala.scalajs.tools.sem.Semantics +import scala.scalajs.tools.javascript.LongImpl +import scala.scalajs.tools.logging._ + +/** Optimizer core. + * Designed to be "mixed in" [[IncOptimizer#MethodImpl#Optimizer]]. + * This is the core of the optimizer. It contains all the smart things the + * optimizer does. To perform inlining, it relies on abstract protected + * methods to identify the target of calls. + */ +private[optimizer] abstract class OptimizerCore(semantics: Semantics) { + import OptimizerCore._ + + type MethodID <: AbstractMethodID + + val myself: MethodID + + /** Returns the body of a method. */ + protected def getMethodBody(method: MethodID): MethodDef + + /** Returns the list of possible targets for a dynamically linked call. */ + protected def dynamicCall(intfName: String, + methodName: String): List[MethodID] + + /** Returns the target of a static call. */ + protected def staticCall(className: String, + methodName: String): Option[MethodID] + + /** Returns the target of a trait impl call. */ + protected def traitImplCall(traitImplName: String, + methodName: String): Option[MethodID] + + /** Returns the list of ancestors of a class or interface. */ + protected def getAncestorsOf(encodedName: String): List[String] + + /** Tests whether the given module class has an elidable accessor. + * In other words, whether it is safe to discard a LoadModule of that + * module class which is not used. + */ + protected def hasElidableModuleAccessor(moduleClassName: String): Boolean + + /** Tests whether the given class is inlineable. + * @return None if the class is not inlineable, Some(value) if it is, where + * value is a RecordValue with the initial value of its fields. + */ + protected def tryNewInlineableClass(className: String): Option[RecordValue] + + private val usedLocalNames = mutable.Set.empty[String] + private val usedLabelNames = mutable.Set.empty[String] + private var statesInUse: List[State[_]] = Nil + + private var disableOptimisticOptimizations: Boolean = false + private var rollbacksCount: Int = 0 + + private val attemptedInlining = mutable.ListBuffer.empty[MethodID] + + private var curTrampolineId = 0 + + def optimize(thisType: Type, originalDef: MethodDef): (MethodDef, Infos.MethodInfo) = { + try { + val MethodDef(name, params, resultType, body) = originalDef + val (newParams, newBody) = try { + transformIsolatedBody(Some(myself), thisType, params, resultType, body) + } catch { + case _: TooManyRollbacksException => + usedLocalNames.clear() + usedLabelNames.clear() + statesInUse = Nil + disableOptimisticOptimizations = true + transformIsolatedBody(Some(myself), thisType, params, resultType, body) + } + val m = MethodDef(name, newParams, resultType, newBody)(None)(originalDef.pos) + val info = recreateInfo(m) + (m, info) + } catch { + case NonFatal(cause) => + throw new OptimizeException(myself, attemptedInlining.distinct.toList, cause) + case e: Throwable => + // This is a fatal exception. Don't wrap, just output debug info error + Console.err.println(exceptionMsg(myself, attemptedInlining.distinct.toList)) + throw e + } + } + + private def withState[A, B](state: State[A])(body: => B): B = { + statesInUse ::= state + try body + finally statesInUse = statesInUse.tail + } + + private def freshLocalName(base: String): String = + freshNameGeneric(usedLocalNames, base) + + private def freshLabelName(base: String): String = + freshNameGeneric(usedLabelNames, base) + + private val isReserved = isKeyword ++ Seq("arguments", "eval", "ScalaJS") + + private def freshNameGeneric(usedNames: mutable.Set[String], base: String): String = { + val result = if (!usedNames.contains(base) && !isReserved(base)) { + base + } else { + var i = 1 + while (usedNames.contains(base + "$" + i)) + i += 1 + base + "$" + i + } + usedNames += result + result + } + + private def tryOrRollback(body: CancelFun => TailRec[Tree])( + fallbackFun: () => TailRec[Tree]): TailRec[Tree] = { + if (disableOptimisticOptimizations) { + fallbackFun() + } else { + val trampolineId = curTrampolineId + val savedUsedLocalNames = usedLocalNames.toSet + val savedUsedLabelNames = usedLabelNames.toSet + val savedStates = statesInUse.map(_.makeBackup()) + + body { () => + throw new RollbackException(trampolineId, savedUsedLocalNames, + savedUsedLabelNames, savedStates, fallbackFun) + } + } + } + + private def isSubclass(lhs: String, rhs: String): Boolean = + getAncestorsOf(lhs).contains(rhs) + + private val isSubclassFun = isSubclass _ + private def isSubtype(lhs: Type, rhs: Type): Boolean = + Types.isSubtype(lhs, rhs)(isSubclassFun) + + /** Transforms a statement. + * + * For valid expression trees, it is always the case that + * {{{ + * transformStat(tree) + * === + * pretransformExpr(tree)(finishTransformStat) + * }}} + */ + private def transformStat(tree: Tree)(implicit scope: Scope): Tree = + transform(tree, isStat = true) + + /** Transforms an expression. + * + * It is always the case that + * {{{ + * transformExpr(tree) + * === + * pretransformExpr(tree)(finishTransformExpr) + * }}} + */ + private def transformExpr(tree: Tree)(implicit scope: Scope): Tree = + transform(tree, isStat = false) + + /** Transforms a tree. */ + private def transform(tree: Tree, isStat: Boolean)( + implicit scope: Scope): Tree = { + + @inline implicit def pos = tree.pos + val result = tree match { + // Definitions + + case VarDef(_, _, _, rhs) => + /* A local var that is last (or alone) in its block is not terribly + * useful. Get rid of it. + * (Non-last VarDefs in blocks are handled in transformBlock.) + */ + transformStat(rhs) + + // Control flow constructs + + case tree: Block => + transformBlock(tree, isStat) + + case Labeled(ident @ Ident(label, _), tpe, body) => + trampoline { + returnable(label, if (isStat) NoType else tpe, body, isStat, + usePreTransform = false)(finishTransform(isStat)) + } + + case Assign(lhs, rhs) => + val cont = { (preTransLhs: PreTransform) => + resolveLocalDef(preTransLhs) match { + case PreTransRecordTree(lhsTree, lhsOrigType, lhsCancelFun) => + val recordType = lhsTree.tpe.asInstanceOf[RecordType] + pretransformNoLocalDef(rhs) { + case PreTransRecordTree(rhsTree, rhsOrigType, rhsCancelFun) => + if (rhsTree.tpe != recordType || rhsOrigType != lhsOrigType) + lhsCancelFun() + TailCalls.done(Assign(lhsTree, rhsTree)) + case _ => + lhsCancelFun() + } + case PreTransTree(lhsTree, _) => + TailCalls.done(Assign(lhsTree, transformExpr(rhs))) + } + } + trampoline { + lhs match { + case lhs: Select => + pretransformSelectCommon(lhs, isLhsOfAssign = true)(cont) + case _ => + pretransformExpr(lhs)(cont) + } + } + + case Return(expr, optLabel) => + val optInfo = optLabel match { + case Some(Ident(label, _)) => + Some(scope.env.labelInfos(label)) + case None => + scope.env.labelInfos.get("") + } + optInfo.fold[Tree] { + Return(transformExpr(expr), None) + } { info => + val newOptLabel = Some(Ident(info.newName, None)) + if (!info.acceptRecords) { + val newExpr = transformExpr(expr) + info.returnedTypes.value ::= (newExpr.tpe, RefinedType(newExpr.tpe)) + Return(newExpr, newOptLabel) + } else trampoline { + pretransformNoLocalDef(expr) { texpr => + texpr match { + case PreTransRecordTree(newExpr, origType, cancelFun) => + info.returnedTypes.value ::= (newExpr.tpe, origType) + TailCalls.done(Return(newExpr, newOptLabel)) + case PreTransTree(newExpr, tpe) => + info.returnedTypes.value ::= (newExpr.tpe, tpe) + TailCalls.done(Return(newExpr, newOptLabel)) + } + } + } + } + + case If(cond, thenp, elsep) => + val newCond = transformExpr(cond) + newCond match { + case BooleanLiteral(condValue) => + if (condValue) transform(thenp, isStat) + else transform(elsep, isStat) + case _ => + val newThenp = transform(thenp, isStat) + val newElsep = transform(elsep, isStat) + val refinedType = + constrainedLub(newThenp.tpe, newElsep.tpe, tree.tpe) + foldIf(newCond, newThenp, newElsep)(refinedType) + } + + case While(cond, body, optLabel) => + val newCond = transformExpr(cond) + newCond match { + case BooleanLiteral(false) => Skip() + case _ => + optLabel match { + case None => + While(newCond, transformStat(body), None) + + case Some(labelIdent @ Ident(label, _)) => + val newLabel = freshLabelName(label) + val info = new LabelInfo(newLabel, acceptRecords = false) + While(newCond, { + val bodyScope = scope.withEnv( + scope.env.withLabelInfo(label, info)) + transformStat(body)(bodyScope) + }, Some(Ident(newLabel, None)(labelIdent.pos))) + } + } + + case DoWhile(body, cond, None) => + val newBody = transformStat(body) + val newCond = transformExpr(cond) + newCond match { + case BooleanLiteral(false) => newBody + case _ => DoWhile(newBody, newCond, None) + } + + case Try(block, errVar, EmptyTree, finalizer) => + val newBlock = transform(block, isStat) + val newFinalizer = transformStat(finalizer) + Try(newBlock, errVar, EmptyTree, newFinalizer)(newBlock.tpe) + + case Try(block, errVar @ Ident(name, originalName), handler, finalizer) => + val newBlock = transform(block, isStat) + + val newName = freshLocalName(name) + val newOriginalName = originalName.orElse(Some(name)) + val localDef = LocalDef(RefinedType(AnyType), true, + ReplaceWithVarRef(newName, newOriginalName, new SimpleState(true), None)) + val newHandler = { + val handlerScope = scope.withEnv(scope.env.withLocalDef(name, localDef)) + transform(handler, isStat)(handlerScope) + } + + val newFinalizer = transformStat(finalizer) + + val refinedType = constrainedLub(newBlock.tpe, newHandler.tpe, tree.tpe) + Try(newBlock, Ident(newName, newOriginalName)(errVar.pos), + newHandler, newFinalizer)(refinedType) + + case Throw(expr) => + Throw(transformExpr(expr)) + + case Continue(optLabel) => + val newOptLabel = optLabel map { label => + Ident(scope.env.labelInfos(label.name).newName, None)(label.pos) + } + Continue(newOptLabel) + + case Match(selector, cases, default) => + val newSelector = transformExpr(selector) + newSelector match { + case newSelector: Literal => + val body = cases collectFirst { + case (alts, body) if alts.exists(literal_===(_, newSelector)) => body + } getOrElse default + transform(body, isStat) + case _ => + Match(newSelector, + cases map (c => (c._1, transform(c._2, isStat))), + transform(default, isStat))(tree.tpe) + } + + // Scala expressions + + case New(cls, ctor, args) => + New(cls, ctor, args map transformExpr) + + case StoreModule(cls, value) => + StoreModule(cls, transformExpr(value)) + + case tree: Select => + trampoline { + pretransformSelectCommon(tree, isLhsOfAssign = false)( + finishTransform(isStat = false)) + } + + case tree: Apply => + trampoline { + pretransformApply(tree, isStat, usePreTransform = false)( + finishTransform(isStat)) + } + + case tree: StaticApply => + trampoline { + pretransformStaticApply(tree, isStat, usePreTransform = false)( + finishTransform(isStat)) + } + + case tree: TraitImplApply => + trampoline { + pretransformTraitImplApply(tree, isStat, usePreTransform = false)( + finishTransform(isStat)) + } + + case tree @ UnaryOp(_, arg) => + if (isStat) transformStat(arg) + else transformUnaryOp(tree) + + case tree @ BinaryOp(op, lhs, rhs) => + if (isStat) Block(transformStat(lhs), transformStat(rhs)) + else transformBinaryOp(tree) + + case NewArray(tpe, lengths) => + NewArray(tpe, lengths map transformExpr) + + case ArrayValue(tpe, elems) => + ArrayValue(tpe, elems map transformExpr) + + case ArrayLength(array) => + ArrayLength(transformExpr(array)) + + case ArraySelect(array, index) => + ArraySelect(transformExpr(array), transformExpr(index))(tree.tpe) + + case RecordValue(tpe, elems) => + RecordValue(tpe, elems map transformExpr) + + case IsInstanceOf(expr, ClassType(ObjectClass)) => + transformExpr(BinaryOp(BinaryOp.!==, expr, Null())) + + case IsInstanceOf(expr, tpe) => + trampoline { + pretransformExpr(expr) { texpr => + val result = { + if (isSubtype(texpr.tpe.base, tpe)) { + if (texpr.tpe.isNullable) + BinaryOp(BinaryOp.!==, finishTransformExpr(texpr), Null()) + else + Block(finishTransformStat(texpr), BooleanLiteral(true)) + } else { + if (texpr.tpe.isExact) + Block(finishTransformStat(texpr), BooleanLiteral(false)) + else + IsInstanceOf(finishTransformExpr(texpr), tpe) + } + } + TailCalls.done(result) + } + } + + case AsInstanceOf(expr, ClassType(ObjectClass)) => + transformExpr(expr) + + case AsInstanceOf(expr, cls) => + trampoline { + pretransformExpr(tree)(finishTransform(isStat)) + } + + case Unbox(arg, charCode) => + trampoline { + pretransformExpr(arg) { targ => + foldUnbox(targ, charCode)(finishTransform(isStat)) + } + } + + case GetClass(expr) => + GetClass(transformExpr(expr)) + + // JavaScript expressions + + case JSNew(ctor, args) => + JSNew(transformExpr(ctor), args map transformExpr) + + case JSDotSelect(qualifier, item) => + JSDotSelect(transformExpr(qualifier), item) + + case JSBracketSelect(qualifier, item) => + JSBracketSelect(transformExpr(qualifier), transformExpr(item)) + + case tree: JSFunctionApply => + trampoline { + pretransformJSFunctionApply(tree, isStat, usePreTransform = false)( + finishTransform(isStat)) + } + + case JSDotMethodApply(receiver, method, args) => + JSDotMethodApply(transformExpr(receiver), method, + args map transformExpr) + + case JSBracketMethodApply(receiver, method, args) => + JSBracketMethodApply(transformExpr(receiver), transformExpr(method), + args map transformExpr) + + case JSDelete(JSDotSelect(obj, prop)) => + JSDelete(JSDotSelect(transformExpr(obj), prop)) + + case JSDelete(JSBracketSelect(obj, prop)) => + JSDelete(JSBracketSelect(transformExpr(obj), transformExpr(prop))) + + case JSUnaryOp(op, lhs) => + JSUnaryOp(op, transformExpr(lhs)) + + case JSBinaryOp(op, lhs, rhs) => + JSBinaryOp(op, transformExpr(lhs), transformExpr(rhs)) + + case JSArrayConstr(items) => + JSArrayConstr(items map transformExpr) + + case JSObjectConstr(fields) => + JSObjectConstr(fields map { + case (name, value) => (name, transformExpr(value)) + }) + + // Atomic expressions + + case _:VarRef | _:This => + trampoline { + pretransformExpr(tree)(finishTransform(isStat)) + } + + case Closure(captureParams, params, body, captureValues) => + transformClosureCommon(captureParams, params, body, + captureValues.map(transformExpr)) + + // Trees that need not be transformed + + case _:Skip | _:Debugger | _:LoadModule | + _:JSEnvInfo | _:Literal | EmptyTree => + tree + } + + if (isStat) keepOnlySideEffects(result) + else result + } + + private def transformClosureCommon(captureParams: List[ParamDef], + params: List[ParamDef], body: Tree, newCaptureValues: List[Tree])( + implicit pos: Position): Closure = { + + val (allNewParams, newBody) = + transformIsolatedBody(None, AnyType, captureParams ++ params, AnyType, body) + val (newCaptureParams, newParams) = + allNewParams.splitAt(captureParams.size) + + Closure(newCaptureParams, newParams, newBody, newCaptureValues) + } + + private def transformBlock(tree: Block, isStat: Boolean)( + implicit scope: Scope): Tree = { + def transformList(stats: List[Tree])( + implicit scope: Scope): Tree = stats match { + case last :: Nil => + transform(last, isStat) + + case (VarDef(Ident(name, originalName), vtpe, mutable, rhs)) :: rest => + trampoline { + pretransformExpr(rhs) { trhs => + withBinding(Binding(name, originalName, vtpe, mutable, trhs)) { + (restScope, cont1) => + val newRest = transformList(rest)(restScope) + cont1(PreTransTree(newRest, RefinedType(newRest.tpe))) + } (finishTransform(isStat)) + } + } + + case stat :: rest => + val transformedStat = transformStat(stat) + if (transformedStat.tpe == NothingType) transformedStat + else Block(transformedStat, transformList(rest))(stat.pos) + + case Nil => // silence the exhaustivity warning in a sensible way + Skip()(tree.pos) + } + transformList(tree.stats)(scope) + } + + /** Pretransforms a list of trees as a list of [[PreTransform]]s. + * This is a convenience method to use pretransformExpr on a list. + */ + private def pretransformExprs(trees: List[Tree])( + cont: List[PreTransform] => TailRec[Tree])( + implicit scope: Scope): TailRec[Tree] = { + trees match { + case first :: rest => + pretransformExpr(first) { tfirst => + pretransformExprs(rest) { trest => + cont(tfirst :: trest) + } + } + + case Nil => + cont(Nil) + } + } + + /** Pretransforms two trees as a pair of [[PreTransform]]s. + * This is a convenience method to use pretransformExpr on two trees. + */ + private def pretransformExprs(tree1: Tree, tree2: Tree)( + cont: (PreTransform, PreTransform) => TailRec[Tree])( + implicit scope: Scope): TailRec[Tree] = { + pretransformExpr(tree1) { ttree1 => + pretransformExpr(tree2) { ttree2 => + cont(ttree1, ttree2) + } + } + } + + /** Pretransforms a tree and a list of trees as [[PreTransform]]s. + * This is a convenience method to use pretransformExpr. + */ + private def pretransformExprs(first: Tree, rest: List[Tree])( + cont: (PreTransform, List[PreTransform]) => TailRec[Tree])( + implicit scope: Scope): TailRec[Tree] = { + pretransformExpr(first) { tfirst => + pretransformExprs(rest) { trest => + cont(tfirst, trest) + } + } + } + + /** Pretransforms a tree to get a refined type while avoiding to force + * things we might be able to optimize by folding and aliasing. + */ + private def pretransformExpr(tree: Tree)(cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = tailcall { + @inline implicit def pos = tree.pos + + tree match { + case tree: Block => + pretransformBlock(tree)(cont) + + case VarRef(Ident(name, _), _) => + val localDef = scope.env.localDefs.getOrElse(name, + sys.error(s"Cannot find local def '$name' at $pos\n" + + s"While optimizing $myself\n" + + s"Env is ${scope.env}\nInlining ${scope.implsBeingInlined}")) + cont(PreTransLocalDef(localDef)) + + case This() => + val localDef = scope.env.localDefs.getOrElse("this", + sys.error(s"Found invalid 'this' at $pos\n" + + s"While optimizing $myself\n" + + s"Env is ${scope.env}\nInlining ${scope.implsBeingInlined}")) + cont(PreTransLocalDef(localDef)) + + case If(cond, thenp, elsep) => + val newCond = transformExpr(cond) + newCond match { + case BooleanLiteral(condValue) => + if (condValue) pretransformExpr(thenp)(cont) + else pretransformExpr(elsep)(cont) + case _ => + tryOrRollback { cancelFun => + pretransformNoLocalDef(thenp) { tthenp => + pretransformNoLocalDef(elsep) { telsep => + (tthenp, telsep) match { + case (PreTransRecordTree(thenTree, thenOrigType, thenCancelFun), + PreTransRecordTree(elseTree, elseOrigType, elseCancelFun)) => + val commonType = + if (thenTree.tpe == elseTree.tpe && + thenOrigType == elseOrigType) thenTree.tpe + else cancelFun() + val refinedOrigType = + constrainedLub(thenOrigType, elseOrigType, tree.tpe) + cont(PreTransRecordTree( + If(newCond, thenTree, elseTree)(commonType), + refinedOrigType, + cancelFun)) + + case (PreTransRecordTree(thenTree, thenOrigType, thenCancelFun), _) + if telsep.tpe.isNothingType => + cont(PreTransRecordTree( + If(newCond, thenTree, finishTransformExpr(telsep))(thenTree.tpe), + thenOrigType, + thenCancelFun)) + + case (_, PreTransRecordTree(elseTree, elseOrigType, elseCancelFun)) + if tthenp.tpe.isNothingType => + cont(PreTransRecordTree( + If(newCond, finishTransformExpr(tthenp), elseTree)(elseTree.tpe), + elseOrigType, + elseCancelFun)) + + case _ => + val newThenp = finishTransformExpr(tthenp) + val newElsep = finishTransformExpr(telsep) + val refinedType = + constrainedLub(newThenp.tpe, newElsep.tpe, tree.tpe) + cont(PreTransTree( + foldIf(newCond, newThenp, newElsep)(refinedType))) + } + } + } + } { () => + val newThenp = transformExpr(thenp) + val newElsep = transformExpr(elsep) + val refinedType = + constrainedLub(newThenp.tpe, newElsep.tpe, tree.tpe) + cont(PreTransTree( + foldIf(newCond, newThenp, newElsep)(refinedType))) + } + } + + case Match(selector, cases, default) => + val newSelector = transformExpr(selector) + newSelector match { + case newSelector: Literal => + val body = cases collectFirst { + case (alts, body) if alts.exists(literal_===(_, newSelector)) => body + } getOrElse default + pretransformExpr(body)(cont) + case _ => + cont(PreTransTree(Match(newSelector, + cases map (c => (c._1, transformExpr(c._2))), + transformExpr(default))(tree.tpe))) + } + + case Labeled(ident @ Ident(label, _), tpe, body) => + returnable(label, tpe, body, isStat = false, usePreTransform = true)(cont) + + case New(cls @ ClassType(className), ctor, args) => + tryNewInlineableClass(className) match { + case Some(initialValue) => + pretransformExprs(args) { targs => + tryOrRollback { cancelFun => + inlineClassConstructor( + new AllocationSite(tree), + cls, initialValue, ctor, targs, cancelFun)(cont) + } { () => + cont(PreTransTree( + New(cls, ctor, targs.map(finishTransformExpr)), + RefinedType(cls, isExact = true, isNullable = false))) + } + } + case None => + cont(PreTransTree( + New(cls, ctor, args.map(transformExpr)), + RefinedType(cls, isExact = true, isNullable = false))) + } + + case tree: Select => + pretransformSelectCommon(tree, isLhsOfAssign = false)(cont) + + case tree: Apply => + pretransformApply(tree, isStat = false, + usePreTransform = true)(cont) + + case tree: StaticApply => + pretransformStaticApply(tree, isStat = false, + usePreTransform = true)(cont) + + case tree: TraitImplApply => + pretransformTraitImplApply(tree, isStat = false, + usePreTransform = true)(cont) + + case tree: JSFunctionApply => + pretransformJSFunctionApply(tree, isStat = false, + usePreTransform = true)(cont) + + case AsInstanceOf(expr, tpe) => + pretransformExpr(expr) { texpr => + tpe match { + case ClassType(ObjectClass) => + cont(texpr) + case _ => + if (isSubtype(texpr.tpe.base, tpe)) { + cont(texpr) + } else { + cont(PreTransTree( + AsInstanceOf(finishTransformExpr(texpr), tpe))) + } + } + } + + case Closure(captureParams, params, body, captureValues) => + pretransformExprs(captureValues) { tcaptureValues => + tryOrRollback { cancelFun => + val captureBindings = for { + (ParamDef(Ident(name, origName), tpe, mutable), value) <- + captureParams zip tcaptureValues + } yield { + Binding(name, origName, tpe, mutable, value) + } + withNewLocalDefs(captureBindings) { (captureLocalDefs, cont1) => + val alreadyUsedState = new SimpleState[Boolean](false) + withState(alreadyUsedState) { + val replacement = TentativeClosureReplacement( + captureParams, params, body, captureLocalDefs, + alreadyUsedState, cancelFun) + val localDef = LocalDef( + RefinedType(AnyType, isExact = false, isNullable = false), + mutable = false, + replacement) + cont1(PreTransLocalDef(localDef)) + } + } (cont) + } { () => + val newClosure = transformClosureCommon(captureParams, params, body, + tcaptureValues.map(finishTransformExpr)) + cont(PreTransTree( + newClosure, + RefinedType(AnyType, isExact = false, isNullable = false))) + } + } + + case _ => + val result = transformExpr(tree) + cont(PreTransTree(result, RefinedType(result.tpe))) + } + } + + private def pretransformBlock(tree: Block)( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + def pretransformList(stats: List[Tree])( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = stats match { + case last :: Nil => + pretransformExpr(last)(cont) + + case (VarDef(Ident(name, originalName), vtpe, mutable, rhs)) :: rest => + pretransformExpr(rhs) { trhs => + withBinding(Binding(name, originalName, vtpe, mutable, trhs)) { + (restScope, cont1) => + pretransformList(rest)(cont1)(restScope) + } (cont) + } + + case stat :: rest => + implicit val pos = tree.pos + val transformedStat = transformStat(stat) + transformedStat match { + case Skip() => + pretransformList(rest)(cont) + case _ => + if (transformedStat.tpe == NothingType) + cont(PreTransTree(transformedStat, RefinedType.Nothing)) + else { + pretransformList(rest) { trest => + cont(PreTransBlock(transformedStat :: Nil, trest)) + } + } + } + + case Nil => // silence the exhaustivity warning in a sensible way + TailCalls.done(Skip()(tree.pos)) + } + pretransformList(tree.stats)(cont)(scope) + } + + private def pretransformSelectCommon(tree: Select, isLhsOfAssign: Boolean)( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + val Select(qualifier, item, mutable) = tree + pretransformExpr(qualifier) { preTransQual => + pretransformSelectCommon(tree.tpe, preTransQual, item, mutable, + isLhsOfAssign)(cont)(scope, tree.pos) + } + } + + private def pretransformSelectCommon(expectedType: Type, + preTransQual: PreTransform, item: Ident, mutable: Boolean, + isLhsOfAssign: Boolean)( + cont: PreTransCont)( + implicit scope: Scope, pos: Position): TailRec[Tree] = { + preTransQual match { + case PreTransLocalDef(LocalDef(_, _, + InlineClassBeingConstructedReplacement(fieldLocalDefs, cancelFun))) => + val fieldLocalDef = fieldLocalDefs(item.name) + if (!isLhsOfAssign || fieldLocalDef.mutable) { + cont(PreTransLocalDef(fieldLocalDef)) + } else { + /* This is an assignment to an immutable field of a inlineable class + * being constructed, but that does not appear at the "top-level" of + * one of its constructors. We cannot handle those, so we cancel. + * (Assignments at the top-level are normal initializations of these + * fields, and are transformed as vals in inlineClassConstructor.) + */ + cancelFun() + } + case PreTransLocalDef(LocalDef(_, _, + InlineClassInstanceReplacement(_, fieldLocalDefs, cancelFun))) => + val fieldLocalDef = fieldLocalDefs(item.name) + if (!isLhsOfAssign || fieldLocalDef.mutable) { + cont(PreTransLocalDef(fieldLocalDef)) + } else { + /* In an ideal world, this should not happen (assigning to an + * immutable field of an already constructed object). However, since + * we cannot IR-check that this does not happen (see #1021), this is + * effectively allowed by the IR spec. We are therefore not allowed + * to crash. We cancel instead. This will become an actual field + * (rather than an optimized local val) which is not considered pure + * (for that same reason). + */ + cancelFun() + } + case _ => + resolveLocalDef(preTransQual) match { + case PreTransRecordTree(newQual, origType, cancelFun) => + val recordType = newQual.tpe.asInstanceOf[RecordType] + val field = recordType.findField(item.name) + val sel = Select(newQual, item, mutable)(field.tpe) + sel.tpe match { + case _: RecordType => + cont(PreTransRecordTree(sel, RefinedType(expectedType), cancelFun)) + case _ => + cont(PreTransTree(sel, RefinedType(sel.tpe))) + } + + case PreTransTree(newQual, _) => + cont(PreTransTree(Select(newQual, item, mutable)(expectedType), + RefinedType(expectedType))) + } + } + } + + /** Resolves any LocalDef in a [[PreTransform]]. */ + private def resolveLocalDef(preTrans: PreTransform): PreTransGenTree = { + implicit val pos = preTrans.pos + preTrans match { + case PreTransBlock(stats, result) => + resolveLocalDef(result) match { + case PreTransRecordTree(tree, tpe, cancelFun) => + PreTransRecordTree(Block(stats :+ tree), tpe, cancelFun) + case PreTransTree(tree, tpe) => + PreTransTree(Block(stats :+ tree), tpe) + } + + case PreTransLocalDef(localDef @ LocalDef(tpe, mutable, replacement)) => + replacement match { + case ReplaceWithRecordVarRef(name, originalName, + recordType, used, cancelFun) => + used.value = true + PreTransRecordTree( + VarRef(Ident(name, originalName), mutable)(recordType), + tpe, cancelFun) + + case InlineClassInstanceReplacement(recordType, fieldLocalDefs, cancelFun) => + if (!isImmutableType(recordType)) + cancelFun() + PreTransRecordTree( + RecordValue(recordType, recordType.fields.map( + f => fieldLocalDefs(f.name).newReplacement)), + tpe, cancelFun) + + case _ => + PreTransTree(localDef.newReplacement, localDef.tpe) + } + + case preTrans: PreTransGenTree => + preTrans + } + } + + /** Combines pretransformExpr and resolveLocalDef in one convenience method. */ + private def pretransformNoLocalDef(tree: Tree)( + cont: PreTransGenTree => TailRec[Tree])( + implicit scope: Scope): TailRec[Tree] = { + pretransformExpr(tree) { ttree => + cont(resolveLocalDef(ttree)) + } + } + + /** Finishes a pretransform, either a statement or an expression. */ + private def finishTransform(isStat: Boolean): PreTransCont = { preTrans => + TailCalls.done { + if (isStat) finishTransformStat(preTrans) + else finishTransformExpr(preTrans) + } + } + + /** Finishes an expression pretransform to get a normal [[Tree]]. + * This method (together with finishTransformStat) must not be called more + * than once per pretransform and per translation. + * By "per translation", we mean in an alternative path through + * `tryOrRollback`. It could still be called several times as long as + * it is once in the 'try' part and once in the 'fallback' part. + */ + private def finishTransformExpr(preTrans: PreTransform): Tree = { + implicit val pos = preTrans.pos + preTrans match { + case PreTransBlock(stats, result) => + Block(stats :+ finishTransformExpr(result)) + case PreTransLocalDef(localDef) => + localDef.newReplacement + case PreTransRecordTree(_, _, cancelFun) => + cancelFun() + case PreTransTree(tree, _) => + tree + } + } + + /** Finishes a statement pretransform to get a normal [[Tree]]. + * This method (together with finishTransformExpr) must not be called more + * than once per pretransform and per translation. + * By "per translation", we mean in an alternative path through + * `tryOrRollback`. It could still be called several times as long as + * it is once in the 'try' part and once in the 'fallback' part. + */ + private def finishTransformStat(stat: PreTransform): Tree = stat match { + case PreTransBlock(stats, result) => + Block(stats :+ finishTransformStat(result))(stat.pos) + case PreTransLocalDef(_) => + Skip()(stat.pos) + case PreTransRecordTree(tree, _, _) => + keepOnlySideEffects(tree) + case PreTransTree(tree, _) => + keepOnlySideEffects(tree) + } + + /** Keeps only the side effects of a Tree (overapproximation). */ + private def keepOnlySideEffects(stat: Tree): Tree = stat match { + case _:VarRef | _:This | _:Literal => + Skip()(stat.pos) + case Block(init :+ last) => + Block(init :+ keepOnlySideEffects(last))(stat.pos) + case LoadModule(ClassType(moduleClassName)) => + if (hasElidableModuleAccessor(moduleClassName)) Skip()(stat.pos) + else stat + case Select(LoadModule(ClassType(moduleClassName)), _, _) => + if (hasElidableModuleAccessor(moduleClassName)) Skip()(stat.pos) + else stat + case Closure(_, _, _, captureValues) => + Block(captureValues.map(keepOnlySideEffects))(stat.pos) + case UnaryOp(_, arg) => + keepOnlySideEffects(arg) + case If(cond, thenp, elsep) => + (keepOnlySideEffects(thenp), keepOnlySideEffects(elsep)) match { + case (Skip(), Skip()) => keepOnlySideEffects(cond) + case (newThenp, newElsep) => If(cond, newThenp, newElsep)(NoType)(stat.pos) + } + case BinaryOp(_, lhs, rhs) => + Block(keepOnlySideEffects(lhs), keepOnlySideEffects(rhs))(stat.pos) + case RecordValue(_, elems) => + Block(elems.map(keepOnlySideEffects))(stat.pos) + case _ => + stat + } + + private def pretransformApply(tree: Apply, isStat: Boolean, + usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + val Apply(receiver, methodIdent @ Ident(methodName, _), args) = tree + implicit val pos = tree.pos + + pretransformExpr(receiver) { treceiver => + def treeNotInlined0(transformedArgs: List[Tree]) = + cont(PreTransTree(Apply(finishTransformExpr(treceiver), methodIdent, + transformedArgs)(tree.tpe)(tree.pos), RefinedType(tree.tpe))) + + def treeNotInlined = treeNotInlined0(args.map(transformExpr)) + + treceiver.tpe.base match { + case NothingType => + cont(treceiver) + case NullType => + cont(PreTransTree(Block( + finishTransformStat(treceiver), + CallHelper("throwNullPointerException")(NothingType)))) + case _ => + if (isReflProxyName(methodName)) { + // Never inline reflective proxies + treeNotInlined + } else { + val cls = boxedClassForType(treceiver.tpe.base) + val impls = + if (treceiver.tpe.isExact) staticCall(cls, methodName).toList + else dynamicCall(cls, methodName) + val allocationSite = treceiver.tpe.allocationSite + if (impls.isEmpty || impls.exists(impl => + scope.implsBeingInlined((allocationSite, impl)))) { + // isEmpty could happen, have to leave it as is for the TypeError + treeNotInlined + } else if (impls.size == 1) { + val target = impls.head + pretransformExprs(args) { targs => + val intrinsicCode = getIntrinsicCode(target) + if (intrinsicCode >= 0) { + callIntrinsic(intrinsicCode, Some(treceiver), targs, + isStat, usePreTransform)(cont) + } else if (target.inlineable || shouldInlineBecauseOfArgs(treceiver :: targs)) { + inline(allocationSite, Some(treceiver), targs, target, + isStat, usePreTransform)(cont) + } else { + treeNotInlined0(targs.map(finishTransformExpr)) + } + } + } else { + if (impls.forall(_.isTraitImplForwarder)) { + val reference = impls.head + val TraitImplApply(ClassType(traitImpl), Ident(methodName, _), _) = + getMethodBody(reference).body + if (!impls.tail.forall(getMethodBody(_).body match { + case TraitImplApply(ClassType(`traitImpl`), + Ident(`methodName`, _), _) => true + case _ => false + })) { + // Not all calling the same method in the same trait impl + treeNotInlined + } else { + pretransformExprs(args) { targs => + inline(allocationSite, Some(treceiver), targs, reference, + isStat, usePreTransform)(cont) + } + } + } else { + // TODO? Inline multiple non-trait-impl-forwarder with the exact same body? + treeNotInlined + } + } + } + } + } + } + + private def boxedClassForType(tpe: Type): String = (tpe: @unchecked) match { + case ClassType(cls) => cls + case AnyType => Definitions.ObjectClass + case UndefType => Definitions.BoxedUnitClass + case BooleanType => Definitions.BoxedBooleanClass + case IntType => Definitions.BoxedIntegerClass + case LongType => Definitions.BoxedLongClass + case FloatType => Definitions.BoxedFloatClass + case DoubleType => Definitions.BoxedDoubleClass + case StringType => Definitions.StringClass + case ArrayType(_, _) => Definitions.ObjectClass + } + + private def pretransformStaticApply(tree: StaticApply, isStat: Boolean, + usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + val StaticApply(receiver, clsType @ ClassType(cls), + methodIdent @ Ident(methodName, _), args) = tree + implicit val pos = tree.pos + + def treeNotInlined0(transformedReceiver: Tree, transformedArgs: List[Tree]) = + cont(PreTransTree(StaticApply(transformedReceiver, clsType, + methodIdent, transformedArgs)(tree.tpe), RefinedType(tree.tpe))) + + def treeNotInlined = + treeNotInlined0(transformExpr(receiver), args.map(transformExpr)) + + if (isReflProxyName(methodName)) { + // Never inline reflective proxies + treeNotInlined + } else { + val optTarget = staticCall(cls, methodName) + if (optTarget.isEmpty) { + // just in case + treeNotInlined + } else { + val target = optTarget.get + pretransformExprs(receiver, args) { (treceiver, targs) => + val intrinsicCode = getIntrinsicCode(target) + if (intrinsicCode >= 0) { + callIntrinsic(intrinsicCode, Some(treceiver), targs, + isStat, usePreTransform)(cont) + } else { + val shouldInline = + target.inlineable || shouldInlineBecauseOfArgs(treceiver :: targs) + val allocationSite = treceiver.tpe.allocationSite + val beingInlined = + scope.implsBeingInlined((allocationSite, target)) + + if (shouldInline && !beingInlined) { + inline(allocationSite, Some(treceiver), targs, target, + isStat, usePreTransform)(cont) + } else { + treeNotInlined0(finishTransformExpr(treceiver), + targs.map(finishTransformExpr)) + } + } + } + } + } + } + + private def pretransformTraitImplApply(tree: TraitImplApply, isStat: Boolean, + usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + val TraitImplApply(implType @ ClassType(impl), + methodIdent @ Ident(methodName, _), args) = tree + implicit val pos = tree.pos + + def treeNotInlined0(transformedArgs: List[Tree]) = + cont(PreTransTree(TraitImplApply(implType, methodIdent, + transformedArgs)(tree.tpe), RefinedType(tree.tpe))) + + def treeNotInlined = treeNotInlined0(args.map(transformExpr)) + + val optTarget = traitImplCall(impl, methodName) + if (optTarget.isEmpty) { + // just in case + treeNotInlined + } else { + val target = optTarget.get + pretransformExprs(args) { targs => + val intrinsicCode = getIntrinsicCode(target) + if (intrinsicCode >= 0) { + callIntrinsic(intrinsicCode, None, targs, + isStat, usePreTransform)(cont) + } else { + val shouldInline = + target.inlineable || shouldInlineBecauseOfArgs(targs) + val allocationSite = targs.headOption.flatMap(_.tpe.allocationSite) + val beingInlined = + scope.implsBeingInlined((allocationSite, target)) + + if (shouldInline && !beingInlined) { + inline(allocationSite, None, targs, target, + isStat, usePreTransform)(cont) + } else { + treeNotInlined0(targs.map(finishTransformExpr)) + } + } + } + } + } + + private def pretransformJSFunctionApply(tree: JSFunctionApply, + isStat: Boolean, usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope, pos: Position): TailRec[Tree] = { + val JSFunctionApply(fun, args) = tree + implicit val pos = tree.pos + + pretransformExpr(fun) { tfun => + tfun match { + case PreTransLocalDef(LocalDef(_, false, + closure @ TentativeClosureReplacement( + captureParams, params, body, captureLocalDefs, + alreadyUsed, cancelFun))) if !alreadyUsed.value => + alreadyUsed.value = true + pretransformExprs(args) { targs => + inlineBody( + Some(PreTransTree(Undefined())), // `this` is `undefined` + captureParams ++ params, AnyType, body, + captureLocalDefs.map(PreTransLocalDef(_)) ++ targs, isStat, + usePreTransform)(cont) + } + + case _ => + cont(PreTransTree( + JSFunctionApply(finishTransformExpr(tfun), args.map(transformExpr)))) + } + } + } + + private def shouldInlineBecauseOfArgs( + receiverAndArgs: List[PreTransform]): Boolean = { + def isLikelyOptimizable(arg: PreTransform): Boolean = arg match { + case PreTransBlock(_, result) => + isLikelyOptimizable(result) + + case PreTransLocalDef(localDef) => + localDef.replacement match { + case TentativeClosureReplacement(_, _, _, _, _, _) => true + case ReplaceWithRecordVarRef(_, _, _, _, _) => true + case InlineClassBeingConstructedReplacement(_, _) => true + case InlineClassInstanceReplacement(_, _, _) => true + case _ => false + } + + case PreTransRecordTree(_, _, _) => + true + + case _ => + arg.tpe.base match { + case ClassType("s_Predef$$less$colon$less" | "s_Predef$$eq$colon$eq") => + true + case _ => + false + } + } + receiverAndArgs.exists(isLikelyOptimizable) + } + + private def inline(allocationSite: Option[AllocationSite], + optReceiver: Option[PreTransform], + args: List[PreTransform], target: MethodID, isStat: Boolean, + usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope, pos: Position): TailRec[Tree] = { + + attemptedInlining += target + + val MethodDef(_, formals, resultType, body) = getMethodBody(target) + + body match { + case Skip() => + assert(isStat, "Found Skip() in expression position") + cont(PreTransTree( + Block((optReceiver ++: args).map(finishTransformStat)), + RefinedType.NoRefinedType)) + + case _: Literal => + cont(PreTransTree( + Block((optReceiver ++: args).map(finishTransformStat) :+ body), + RefinedType(body.tpe))) + + case This() if args.isEmpty => + assert(optReceiver.isDefined, + "There was a This(), there should be a receiver") + cont(optReceiver.get) + + case Select(This(), field, mutable) if formals.isEmpty => + assert(optReceiver.isDefined, + "There was a This(), there should be a receiver") + pretransformSelectCommon(body.tpe, optReceiver.get, field, mutable, + isLhsOfAssign = false)(cont) + + case Assign(lhs @ Select(This(), field, mutable), VarRef(Ident(rhsName, _), _)) + if formals.size == 1 && formals.head.name.name == rhsName => + assert(isStat, "Found Assign in expression position") + assert(optReceiver.isDefined, + "There was a This(), there should be a receiver") + pretransformSelectCommon(lhs.tpe, optReceiver.get, field, mutable, + isLhsOfAssign = true) { preTransLhs => + // TODO Support assignment of record + cont(PreTransTree( + Assign(finishTransformExpr(preTransLhs), + finishTransformExpr(args.head)), + RefinedType.NoRefinedType)) + } + + case _ => + val targetID = (allocationSite, target) + inlineBody(optReceiver, formals, resultType, body, args, isStat, + usePreTransform)(cont)(scope.inlining(targetID), pos) + } + } + + private def inlineBody(optReceiver: Option[PreTransform], + formals: List[ParamDef], resultType: Type, body: Tree, + args: List[PreTransform], isStat: Boolean, + usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope, pos: Position): TailRec[Tree] = tailcall { + + val optReceiverBinding = optReceiver map { receiver => + Binding("this", None, receiver.tpe.base, false, receiver) + } + + val argsBindings = for { + (ParamDef(Ident(name, originalName), tpe, mutable), arg) <- formals zip args + } yield { + Binding(name, originalName, tpe, mutable, arg) + } + + withBindings(optReceiverBinding ++: argsBindings) { (bodyScope, cont1) => + returnable("", resultType, body, isStat, usePreTransform)( + cont1)(bodyScope, pos) + } (cont) (scope.withEnv(OptEnv.Empty)) + } + + private def callIntrinsic(code: Int, optTReceiver: Option[PreTransform], + targs: List[PreTransform], isStat: Boolean, usePreTransform: Boolean)( + cont: PreTransCont)( + implicit pos: Position): TailRec[Tree] = { + + import Intrinsics._ + + implicit def string2ident(s: String): Ident = Ident(s, None) + + lazy val newArgs = targs.map(finishTransformExpr) + + @inline def contTree(result: Tree) = cont(PreTransTree(result)) + + @inline def StringClassType = ClassType(Definitions.StringClass) + + def asRTLong(arg: Tree): Tree = + AsInstanceOf(arg, ClassType(LongImpl.RuntimeLongClass)) + def firstArgAsRTLong: Tree = + asRTLong(newArgs.head) + + (code: @switch) match { + // java.lang.System + + case ArrayCopy => + assert(isStat, "System.arraycopy must be used in statement position") + contTree(CallHelper("systemArraycopy", newArgs)(NoType)) + case IdentityHashCode => + contTree(CallHelper("systemIdentityHashCode", newArgs)(IntType)) + + // scala.scalajs.runtime package object + + case PropertiesOf => + contTree(CallHelper("propertiesOf", newArgs)(AnyType)) + + // java.lang.Long + + case LongToString => + contTree(Apply(firstArgAsRTLong, "toString__T", Nil)(StringClassType)) + case LongCompare => + contTree(Apply(firstArgAsRTLong, "compareTo__sjsr_RuntimeLong__I", + List(asRTLong(newArgs(1))))(IntType)) + case LongBitCount => + contTree(Apply(firstArgAsRTLong, LongImpl.bitCount, Nil)(IntType)) + case LongSignum => + contTree(Apply(firstArgAsRTLong, LongImpl.signum, Nil)(LongType)) + case LongLeading0s => + contTree(Apply(firstArgAsRTLong, LongImpl.numberOfLeadingZeros, Nil)(IntType)) + case LongTrailing0s => + contTree(Apply(firstArgAsRTLong, LongImpl.numberOfTrailingZeros, Nil)(IntType)) + case LongToBinStr => + contTree(Apply(firstArgAsRTLong, LongImpl.toBinaryString, Nil)(StringClassType)) + case LongToHexStr => + contTree(Apply(firstArgAsRTLong, LongImpl.toHexString, Nil)(StringClassType)) + case LongToOctalStr => + contTree(Apply(firstArgAsRTLong, LongImpl.toOctalString, Nil)(StringClassType)) + + // TypedArray conversions + + case ByteArrayToInt8Array => + contTree(CallHelper("byteArray2TypedArray", newArgs)(AnyType)) + case ShortArrayToInt16Array => + contTree(CallHelper("shortArray2TypedArray", newArgs)(AnyType)) + case CharArrayToUint16Array => + contTree(CallHelper("charArray2TypedArray", newArgs)(AnyType)) + case IntArrayToInt32Array => + contTree(CallHelper("intArray2TypedArray", newArgs)(AnyType)) + case FloatArrayToFloat32Array => + contTree(CallHelper("floatArray2TypedArray", newArgs)(AnyType)) + case DoubleArrayToFloat64Array => + contTree(CallHelper("doubleArray2TypedArray", newArgs)(AnyType)) + + case Int8ArrayToByteArray => + contTree(CallHelper("typedArray2ByteArray", newArgs)(AnyType)) + case Int16ArrayToShortArray => + contTree(CallHelper("typedArray2ShortArray", newArgs)(AnyType)) + case Uint16ArrayToCharArray => + contTree(CallHelper("typedArray2CharArray", newArgs)(AnyType)) + case Int32ArrayToIntArray => + contTree(CallHelper("typedArray2IntArray", newArgs)(AnyType)) + case Float32ArrayToFloatArray => + contTree(CallHelper("typedArray2FloatArray", newArgs)(AnyType)) + case Float64ArrayToDoubleArray => + contTree(CallHelper("typedArray2DoubleArray", newArgs)(AnyType)) + } + } + + private def inlineClassConstructor(allocationSite: AllocationSite, + cls: ClassType, initialValue: RecordValue, + ctor: Ident, args: List[PreTransform], cancelFun: CancelFun)( + cont: PreTransCont)( + implicit scope: Scope, pos: Position): TailRec[Tree] = { + + val RecordValue(recordType, initialFieldValues) = initialValue + + pretransformExprs(initialFieldValues) { tinitialFieldValues => + val initialFieldBindings = for { + (RecordType.Field(name, originalName, tpe, mutable), value) <- + recordType.fields zip tinitialFieldValues + } yield { + Binding(name, originalName, tpe, mutable, value) + } + + withNewLocalDefs(initialFieldBindings) { (initialFieldLocalDefList, cont1) => + val fieldNames = initialValue.tpe.fields.map(_.name) + val initialFieldLocalDefs = + Map(fieldNames zip initialFieldLocalDefList: _*) + + inlineClassConstructorBody(allocationSite, initialFieldLocalDefs, + cls, cls, ctor, args, cancelFun) { (finalFieldLocalDefs, cont2) => + cont2(PreTransLocalDef(LocalDef( + RefinedType(cls, isExact = true, isNullable = false, + allocationSite = Some(allocationSite)), + mutable = false, + InlineClassInstanceReplacement(recordType, finalFieldLocalDefs, cancelFun)))) + } (cont1) + } (cont) + } + } + + private def inlineClassConstructorBody( + allocationSite: AllocationSite, + inputFieldsLocalDefs: Map[String, LocalDef], cls: ClassType, + ctorClass: ClassType, ctor: Ident, args: List[PreTransform], + cancelFun: CancelFun)( + buildInner: (Map[String, LocalDef], PreTransCont) => TailRec[Tree])( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = tailcall { + + val target = staticCall(ctorClass.className, ctor.name).getOrElse(cancelFun()) + val targetID = (Some(allocationSite), target) + if (scope.implsBeingInlined.contains(targetID)) + cancelFun() + + val MethodDef(_, formals, _, BlockOrAlone(stats, This())) = + getMethodBody(target) + + val argsBindings = for { + (ParamDef(Ident(name, originalName), tpe, mutable), arg) <- formals zip args + } yield { + Binding(name, originalName, tpe, mutable, arg) + } + + withBindings(argsBindings) { (bodyScope, cont1) => + val thisLocalDef = LocalDef( + RefinedType(cls, isExact = true, isNullable = false), false, + InlineClassBeingConstructedReplacement(inputFieldsLocalDefs, cancelFun)) + val statsScope = bodyScope.inlining(targetID).withEnv( + bodyScope.env.withLocalDef("this", thisLocalDef)) + inlineClassConstructorBodyList(allocationSite, thisLocalDef, + inputFieldsLocalDefs, cls, stats, cancelFun)( + buildInner)(cont1)(statsScope) + } (cont) (scope.withEnv(OptEnv.Empty)) + } + + private def inlineClassConstructorBodyList( + allocationSite: AllocationSite, + thisLocalDef: LocalDef, inputFieldsLocalDefs: Map[String, LocalDef], + cls: ClassType, stats: List[Tree], cancelFun: CancelFun)( + buildInner: (Map[String, LocalDef], PreTransCont) => TailRec[Tree])( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + stats match { + case This() :: rest => + inlineClassConstructorBodyList(allocationSite, thisLocalDef, + inputFieldsLocalDefs, cls, rest, cancelFun)(buildInner)(cont) + + case Assign(s @ Select(ths: This, + Ident(fieldName, fieldOrigName), false), value) :: rest => + pretransformExpr(value) { tvalue => + withNewLocalDef(Binding(fieldName, fieldOrigName, s.tpe, false, + tvalue)) { (localDef, cont1) => + if (localDef.contains(thisLocalDef)) { + /* Uh oh, there is a `val x = ...this...`. We can't keep it, + * because this field will not be updated with `newThisLocalDef`. + */ + cancelFun() + } + val newFieldsLocalDefs = + inputFieldsLocalDefs.updated(fieldName, localDef) + val newThisLocalDef = LocalDef( + RefinedType(cls, isExact = true, isNullable = false), false, + InlineClassBeingConstructedReplacement(newFieldsLocalDefs, cancelFun)) + val restScope = scope.withEnv(scope.env.withLocalDef( + "this", newThisLocalDef)) + inlineClassConstructorBodyList(allocationSite, + newThisLocalDef, newFieldsLocalDefs, cls, rest, cancelFun)( + buildInner)(cont1)(restScope) + } (cont) + } + + /* if (cond) + * throw e + * else + * this.outer = value + * + * becomes + * + * this.outer = + * if (cond) throw e + * else value + * + * Typical shape of initialization of outer pointer of inner classes. + */ + case If(cond, th: Throw, + Assign(Select(This(), _, false), value)) :: rest => + // work around a bug of the compiler (these should be @-bindings) + val stat = stats.head.asInstanceOf[If] + val ass = stat.elsep.asInstanceOf[Assign] + val lhs = ass.lhs + inlineClassConstructorBodyList(allocationSite, thisLocalDef, + inputFieldsLocalDefs, cls, + Assign(lhs, If(cond, th, value)(lhs.tpe)(stat.pos))(ass.pos) :: rest, + cancelFun)(buildInner)(cont) + + case StaticApply(ths: This, superClass, superCtor, args) :: rest + if isConstructorName(superCtor.name) => + pretransformExprs(args) { targs => + inlineClassConstructorBody(allocationSite, inputFieldsLocalDefs, + cls, superClass, superCtor, targs, + cancelFun) { (outputFieldsLocalDefs, cont1) => + val newThisLocalDef = LocalDef( + RefinedType(cls, isExact = true, isNullable = false), false, + InlineClassBeingConstructedReplacement(outputFieldsLocalDefs, cancelFun)) + val restScope = scope.withEnv(scope.env.withLocalDef( + "this", newThisLocalDef)) + inlineClassConstructorBodyList(allocationSite, + newThisLocalDef, outputFieldsLocalDefs, + cls, rest, cancelFun)(buildInner)(cont1)(restScope) + } (cont) + } + + case VarDef(Ident(name, originalName), tpe, mutable, rhs) :: rest => + pretransformExpr(rhs) { trhs => + withBinding(Binding(name, originalName, tpe, mutable, trhs)) { (restScope, cont1) => + inlineClassConstructorBodyList(allocationSite, + thisLocalDef, inputFieldsLocalDefs, + cls, rest, cancelFun)(buildInner)(cont1)(restScope) + } (cont) + } + + case stat :: rest => + val transformedStat = transformStat(stat) + transformedStat match { + case Skip() => + inlineClassConstructorBodyList(allocationSite, + thisLocalDef, inputFieldsLocalDefs, + cls, rest, cancelFun)(buildInner)(cont) + case _ => + if (transformedStat.tpe == NothingType) + cont(PreTransTree(transformedStat, RefinedType.Nothing)) + else { + inlineClassConstructorBodyList(allocationSite, + thisLocalDef, inputFieldsLocalDefs, + cls, rest, cancelFun) { (outputFieldsLocalDefs, cont1) => + buildInner(outputFieldsLocalDefs, { tinner => + cont1(PreTransBlock(transformedStat :: Nil, tinner)) + }) + }(cont) + } + } + + case Nil => + buildInner(inputFieldsLocalDefs, cont) + } + } + + private def foldIf(cond: Tree, thenp: Tree, elsep: Tree)(tpe: Type)( + implicit pos: Position): Tree = { + import BinaryOp._ + + @inline def default = If(cond, thenp, elsep)(tpe) + cond match { + case BooleanLiteral(v) => + if (v) thenp + else elsep + + case _ => + @inline def negCond = foldUnaryOp(UnaryOp.Boolean_!, cond) + if (thenp.tpe == BooleanType && elsep.tpe == BooleanType) { + (cond, thenp, elsep) match { + case (_, BooleanLiteral(t), BooleanLiteral(e)) => + if (t == e) Block(keepOnlySideEffects(cond), thenp) + else if (t) cond + else negCond + + case (_, BooleanLiteral(false), _) => + foldIf(negCond, elsep, BooleanLiteral(false))(tpe) // canonical && form + case (_, _, BooleanLiteral(true)) => + foldIf(negCond, BooleanLiteral(true), thenp)(tpe) // canonical || form + + /* if (lhs === null) rhs === null else lhs === rhs + * -> lhs === rhs + * This is the typical shape of a lhs == rhs test where + * the equals() method has been inlined as a reference + * equality test. + */ + case (BinaryOp(BinaryOp.===, VarRef(lhsIdent, _), Null()), + BinaryOp(BinaryOp.===, VarRef(rhsIdent, _), Null()), + BinaryOp(BinaryOp.===, VarRef(lhsIdent2, _), VarRef(rhsIdent2, _))) + if lhsIdent2 == lhsIdent && rhsIdent2 == rhsIdent => + elsep + + // Example: (x > y) || (x == y) -> (x >= y) + case (BinaryOp(op1 @ (Num_== | Num_!= | Num_< | Num_<= | Num_> | Num_>=), l1, r1), + BooleanLiteral(true), + BinaryOp(op2 @ (Num_== | Num_!= | Num_< | Num_<= | Num_> | Num_>=), l2, r2)) + if ((l1.isInstanceOf[Literal] || l1.isInstanceOf[VarRef]) && + (r1.isInstanceOf[Literal] || r1.isInstanceOf[VarRef]) && + (l1 == l2 && r1 == r2)) => + val canBeEqual = + ((op1 == Num_==) || (op1 == Num_<=) || (op1 == Num_>=)) || + ((op2 == Num_==) || (op2 == Num_<=) || (op2 == Num_>=)) + val canBeLessThan = + ((op1 == Num_!=) || (op1 == Num_<) || (op1 == Num_<=)) || + ((op2 == Num_!=) || (op2 == Num_<) || (op2 == Num_<=)) + val canBeGreaterThan = + ((op1 == Num_!=) || (op1 == Num_>) || (op1 == Num_>=)) || + ((op2 == Num_!=) || (op2 == Num_>) || (op2 == Num_>=)) + + fold3WayComparison(canBeEqual, canBeLessThan, canBeGreaterThan, l1, r1) + + // Example: (x >= y) && (x <= y) -> (x == y) + case (BinaryOp(op1 @ (Num_== | Num_!= | Num_< | Num_<= | Num_> | Num_>=), l1, r1), + BinaryOp(op2 @ (Num_== | Num_!= | Num_< | Num_<= | Num_> | Num_>=), l2, r2), + BooleanLiteral(false)) + if ((l1.isInstanceOf[Literal] || l1.isInstanceOf[VarRef]) && + (r1.isInstanceOf[Literal] || r1.isInstanceOf[VarRef]) && + (l1 == l2 && r1 == r2)) => + val canBeEqual = + ((op1 == Num_==) || (op1 == Num_<=) || (op1 == Num_>=)) && + ((op2 == Num_==) || (op2 == Num_<=) || (op2 == Num_>=)) + val canBeLessThan = + ((op1 == Num_!=) || (op1 == Num_<) || (op1 == Num_<=)) && + ((op2 == Num_!=) || (op2 == Num_<) || (op2 == Num_<=)) + val canBeGreaterThan = + ((op1 == Num_!=) || (op1 == Num_>) || (op1 == Num_>=)) && + ((op2 == Num_!=) || (op2 == Num_>) || (op2 == Num_>=)) + + fold3WayComparison(canBeEqual, canBeLessThan, canBeGreaterThan, l1, r1) + + case _ => default + } + } else { + (thenp, elsep) match { + case (Skip(), Skip()) => keepOnlySideEffects(cond) + case (Skip(), _) => foldIf(negCond, elsep, thenp)(tpe) + + case _ => default + } + } + } + } + + private def transformUnaryOp(tree: UnaryOp)(implicit scope: Scope): Tree = { + import UnaryOp._ + + implicit val pos = tree.pos + val UnaryOp(op, arg) = tree + + (op: @switch) match { + case LongToInt => + trampoline { + pretransformExpr(arg) { (targ) => + TailCalls.done { + foldUnaryOp(op, finishTransformOptLongExpr(targ)) + } + } + } + + case _ => + foldUnaryOp(op, transformExpr(arg)) + } + } + + private def transformBinaryOp(tree: BinaryOp)(implicit scope: Scope): Tree = { + import BinaryOp._ + + implicit val pos = tree.pos + val BinaryOp(op, lhs, rhs) = tree + + (op: @switch) match { + case === | !== => + trampoline { + pretransformExprs(lhs, rhs) { (tlhs, trhs) => + TailCalls.done(foldReferenceEquality(tlhs, trhs, op == ===)) + } + } + + case Long_== | Long_!= | Long_< | Long_<= | Long_> | Long_>= => + trampoline { + pretransformExprs(lhs, rhs) { (tlhs, trhs) => + TailCalls.done { + if (isLiteralOrOptimizableLong(tlhs) && + isLiteralOrOptimizableLong(trhs)) { + foldBinaryOp(op, finishTransformOptLongExpr(tlhs), + finishTransformOptLongExpr(trhs)) + } else { + foldBinaryOp(op, finishTransformExpr(tlhs), + finishTransformExpr(trhs)) + } + } + } + } + + case _ => + foldBinaryOp(op, transformExpr(lhs), transformExpr(rhs)) + } + } + + private def isLiteralOrOptimizableLong(texpr: PreTransform): Boolean = { + texpr match { + case PreTransTree(LongLiteral(_), _) => + true + case PreTransLocalDef(LocalDef(_, _, replacement)) => + replacement match { + case ReplaceWithVarRef(_, _, _, Some(_)) => true + case ReplaceWithConstant(LongLiteral(_)) => true + case _ => false + } + case _ => + false + } + } + + private def finishTransformOptLongExpr(targ: PreTransform): Tree = targ match { + case PreTransLocalDef(LocalDef(tpe, false, + ReplaceWithVarRef(_, _, _, Some(argValue)))) => + argValue() + case _ => + finishTransformExpr(targ) + } + + private def foldUnaryOp(op: UnaryOp.Code, arg: Tree)( + implicit pos: Position): Tree = { + import UnaryOp._ + @inline def default = UnaryOp(op, arg) + (op: @switch) match { + case Boolean_! => + arg match { + case BooleanLiteral(v) => BooleanLiteral(!v) + case UnaryOp(Boolean_!, x) => x + + case BinaryOp(innerOp, l, r) => + val newOp = (innerOp: @switch) match { + case BinaryOp.=== => BinaryOp.!== + case BinaryOp.!== => BinaryOp.=== + + case BinaryOp.Num_== => BinaryOp.Num_!= + case BinaryOp.Num_!= => BinaryOp.Num_== + case BinaryOp.Num_< => BinaryOp.Num_>= + case BinaryOp.Num_<= => BinaryOp.Num_> + case BinaryOp.Num_> => BinaryOp.Num_<= + case BinaryOp.Num_>= => BinaryOp.Num_< + + case BinaryOp.Long_== => BinaryOp.Long_!= + case BinaryOp.Long_!= => BinaryOp.Long_== + case BinaryOp.Long_< => BinaryOp.Long_>= + case BinaryOp.Long_<= => BinaryOp.Long_> + case BinaryOp.Long_> => BinaryOp.Long_<= + case BinaryOp.Long_>= => BinaryOp.Long_< + + case BinaryOp.Boolean_== => BinaryOp.Boolean_!= + case BinaryOp.Boolean_!= => BinaryOp.Boolean_== + + case _ => -1 + } + if (newOp == -1) default + else BinaryOp(newOp, l, r) + + case _ => default + } + + case IntToLong => + arg match { + case IntLiteral(v) => LongLiteral(v.toLong) + case _ => default + } + + case LongToInt => + arg match { + case LongLiteral(v) => IntLiteral(v.toInt) + case UnaryOp(IntToLong, x) => x + + case BinaryOp(BinaryOp.Long_+, x, y) => + foldBinaryOp(BinaryOp.Int_+, + foldUnaryOp(LongToInt, x), + foldUnaryOp(LongToInt, y)) + case BinaryOp(BinaryOp.Long_-, x, y) => + foldBinaryOp(BinaryOp.Int_-, + foldUnaryOp(LongToInt, x), + foldUnaryOp(LongToInt, y)) + + case _ => default + } + + case LongToDouble => + arg match { + case LongLiteral(v) => DoubleLiteral(v.toDouble) + case _ => default + } + case DoubleToInt => + arg match { + case _ if arg.tpe == IntType => arg + case NumberLiteral(v) => IntLiteral(v.toInt) + case _ => default + } + case DoubleToFloat => + arg match { + case _ if arg.tpe == FloatType => arg + case NumberLiteral(v) => FloatLiteral(v.toFloat) + case _ => default + } + case DoubleToLong => + arg match { + case _ if arg.tpe == IntType => foldUnaryOp(IntToLong, arg) + case NumberLiteral(v) => LongLiteral(v.toLong) + case _ => default + } + case _ => + default + } + } + + /** Performs === for two literals. + * The result is always known statically. + */ + private def literal_===(lhs: Literal, rhs: Literal): Boolean = { + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => l == r + case (FloatLiteral(l), FloatLiteral(r)) => l == r + case (NumberLiteral(l), NumberLiteral(r)) => l == r + case (LongLiteral(l), LongLiteral(r)) => l == r + case (BooleanLiteral(l), BooleanLiteral(r)) => l == r + case (StringLiteral(l), StringLiteral(r)) => l == r + case (Undefined(), Undefined()) => true + case (Null(), Null()) => true + case _ => false + } + } + + private def foldBinaryOp(op: BinaryOp.Code, lhs: Tree, rhs: Tree)( + implicit pos: Position): Tree = { + import BinaryOp._ + @inline def default = BinaryOp(op, lhs, rhs) + (op: @switch) match { + case === | !== => + val positive = (op == ===) + (lhs, rhs) match { + case (lhs: Literal, rhs: Literal) => + BooleanLiteral(literal_===(lhs, rhs) == positive) + + case (_: Literal, _) => foldBinaryOp(op, rhs, lhs) + case _ => default + } + + case Int_+ => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l + r) + case (_, IntLiteral(_)) => foldBinaryOp(Int_+, rhs, lhs) + case (IntLiteral(0), _) => rhs + + case (IntLiteral(x), + BinaryOp(innerOp @ (Int_+ | Int_-), IntLiteral(y), z)) => + foldBinaryOp(innerOp, IntLiteral(x+y), z) + + case _ => default + } + + case Int_- => + (lhs, rhs) match { + case (_, IntLiteral(r)) => foldBinaryOp(Int_+, lhs, IntLiteral(-r)) + + case (IntLiteral(x), BinaryOp(Int_+, IntLiteral(y), z)) => + foldBinaryOp(Int_-, IntLiteral(x-y), z) + case (IntLiteral(x), BinaryOp(Int_-, IntLiteral(y), z)) => + foldBinaryOp(Int_+, IntLiteral(x-y), z) + + case (_, BinaryOp(Int_-, IntLiteral(0), x)) => + foldBinaryOp(Int_+, lhs, x) + + case _ => default + } + + case Int_* => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l * r) + case (_, IntLiteral(_)) => foldBinaryOp(Int_*, rhs, lhs) + + case (IntLiteral(1), _) => rhs + case (IntLiteral(-1), _) => foldBinaryOp(Int_-, IntLiteral(0), lhs) + + case _ => default + } + + case Int_/ => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) if r != 0 => IntLiteral(l / r) + + case (_, IntLiteral(1)) => lhs + case (_, IntLiteral(-1)) => foldBinaryOp(Int_-, IntLiteral(0), lhs) + + case _ => default + } + + case Int_% => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) if r != 0 => IntLiteral(l % r) + case (_, IntLiteral(1 | -1)) => + Block(keepOnlySideEffects(lhs), IntLiteral(0)) + case _ => default + } + + case Int_| => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l | r) + case (_, IntLiteral(_)) => foldBinaryOp(Int_|, rhs, lhs) + case (IntLiteral(0), _) => rhs + + case (IntLiteral(x), BinaryOp(Int_|, IntLiteral(y), z)) => + foldBinaryOp(Int_|, IntLiteral(x | y), z) + + case _ => default + } + + case Int_& => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l & r) + case (_, IntLiteral(_)) => foldBinaryOp(Int_&, rhs, lhs) + case (IntLiteral(-1), _) => rhs + + case (IntLiteral(x), BinaryOp(Int_&, IntLiteral(y), z)) => + foldBinaryOp(Int_&, IntLiteral(x & y), z) + + case _ => default + } + + case Int_^ => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l ^ r) + case (_, IntLiteral(_)) => foldBinaryOp(Int_^, rhs, lhs) + case (IntLiteral(0), _) => rhs + + case (IntLiteral(x), BinaryOp(Int_^, IntLiteral(y), z)) => + foldBinaryOp(Int_^, IntLiteral(x ^ y), z) + + case _ => default + } + + case Int_<< => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l << r) + case (_, IntLiteral(x)) if x % 32 == 0 => lhs + case _ => default + } + + case Int_>>> => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l >>> r) + case (_, IntLiteral(x)) if x % 32 == 0 => lhs + case _ => default + } + + case Int_>> => + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => IntLiteral(l >> r) + case (_, IntLiteral(x)) if x % 32 == 0 => lhs + case _ => default + } + + case Long_+ => + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l + r) + case (_, LongLiteral(_)) => foldBinaryOp(Long_+, rhs, lhs) + case (LongLiteral(0), _) => rhs + + case (LongLiteral(x), + BinaryOp(innerOp @ (Long_+ | Long_-), LongLiteral(y), z)) => + foldBinaryOp(innerOp, LongLiteral(x+y), z) + + case _ => default + } + + case Long_- => + (lhs, rhs) match { + case (_, LongLiteral(r)) => foldBinaryOp(Long_+, LongLiteral(-r), lhs) + + case (LongLiteral(x), BinaryOp(Long_+, LongLiteral(y), z)) => + foldBinaryOp(Long_-, LongLiteral(x-y), z) + case (LongLiteral(x), BinaryOp(Long_-, LongLiteral(y), z)) => + foldBinaryOp(Long_+, LongLiteral(x-y), z) + + case (_, BinaryOp(BinaryOp.Long_-, LongLiteral(0L), x)) => + foldBinaryOp(Long_+, lhs, x) + + case _ => default + } + + case Long_* => + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l * r) + case (_, LongLiteral(_)) => foldBinaryOp(Long_*, rhs, lhs) + + case (LongLiteral(1), _) => rhs + case (LongLiteral(-1), _) => foldBinaryOp(Long_-, LongLiteral(0), lhs) + + case _ => default + } + + case Long_/ => + (lhs, rhs) match { + case (_, LongLiteral(0)) => default + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l / r) + + case (_, LongLiteral(1)) => lhs + case (_, LongLiteral(-1)) => foldBinaryOp(Long_-, LongLiteral(0), lhs) + + case (LongFromInt(x), LongFromInt(y: IntLiteral)) if y.value != -1 => + LongFromInt(foldBinaryOp(Int_/, x, y)) + + case _ => default + } + + case Long_% => + (lhs, rhs) match { + case (_, LongLiteral(0)) => default + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l % r) + + case (_, LongLiteral(1L | -1L)) => + Block(keepOnlySideEffects(lhs), LongLiteral(0L)) + + case (LongFromInt(x), LongFromInt(y)) => + LongFromInt(foldBinaryOp(Int_%, x, y)) + + case _ => default + } + + case Long_| => + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l | r) + case (_, LongLiteral(_)) => foldBinaryOp(Long_|, rhs, lhs) + case (LongLiteral(0), _) => rhs + + case (LongLiteral(x), BinaryOp(Long_|, LongLiteral(y), z)) => + foldBinaryOp(Long_|, LongLiteral(x | y), z) + + case _ => default + } + + case Long_& => + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l & r) + case (_, LongLiteral(_)) => foldBinaryOp(Long_&, rhs, lhs) + case (LongLiteral(-1), _) => rhs + + case (LongLiteral(x), BinaryOp(Long_&, LongLiteral(y), z)) => + foldBinaryOp(Long_&, LongLiteral(x & y), z) + + case _ => default + } + + case Long_^ => + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => LongLiteral(l ^ r) + case (_, LongLiteral(_)) => foldBinaryOp(Long_^, rhs, lhs) + case (LongLiteral(0), _) => rhs + + case (LongLiteral(x), BinaryOp(Long_^, LongLiteral(y), z)) => + foldBinaryOp(Long_^, LongLiteral(x ^ y), z) + + case _ => default + } + + case Long_<< => + (lhs, rhs) match { + case (LongLiteral(l), IntLiteral(r)) => LongLiteral(l << r) + case (_, IntLiteral(x)) if x % 64 == 0 => lhs + case _ => default + } + + case Long_>>> => + (lhs, rhs) match { + case (LongLiteral(l), IntLiteral(r)) => LongLiteral(l >>> r) + case (_, IntLiteral(x)) if x % 64 == 0 => lhs + case _ => default + } + + case Long_>> => + (lhs, rhs) match { + case (LongLiteral(l), IntLiteral(r)) => LongLiteral(l >> r) + case (_, IntLiteral(x)) if x % 64 == 0 => lhs + case _ => default + } + + case Long_== | Long_!= => + val positive = (op == Long_==) + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => + BooleanLiteral((l == r) == positive) + + case (LongFromInt(x), LongFromInt(y)) => + foldBinaryOp(if (positive) === else !==, x, y) + case (LongFromInt(x), LongLiteral(y)) => + assert(y > Int.MaxValue || y < Int.MinValue) + Block(keepOnlySideEffects(x), BooleanLiteral(!positive)) + + case (BinaryOp(Long_+, LongLiteral(x), y), LongLiteral(z)) => + foldBinaryOp(op, y, LongLiteral(z-x)) + case (BinaryOp(Long_-, LongLiteral(x), y), LongLiteral(z)) => + foldBinaryOp(op, y, LongLiteral(x-z)) + + case (LongLiteral(_), _) => foldBinaryOp(op, rhs, lhs) + case _ => default + } + + case Long_< | Long_<= | Long_> | Long_>= => + def flippedOp = (op: @switch) match { + case Long_< => Long_> + case Long_<= => Long_>= + case Long_> => Long_< + case Long_>= => Long_<= + } + + def intOp = (op: @switch) match { + case Long_< => Num_< + case Long_<= => Num_<= + case Long_> => Num_> + case Long_>= => Num_>= + } + + (lhs, rhs) match { + case (LongLiteral(l), LongLiteral(r)) => + val result = (op: @switch) match { + case Long_< => l < r + case Long_<= => l <= r + case Long_> => l > r + case Long_>= => l >= r + } + BooleanLiteral(result) + + case (_, LongLiteral(Long.MinValue)) => + if (op == Long_< || op == Long_>=) + Block(keepOnlySideEffects(lhs), BooleanLiteral(op == Long_>=)) + else + foldBinaryOp(if (op == Long_<=) Long_== else Long_!=, lhs, rhs) + + case (_, LongLiteral(Long.MaxValue)) => + if (op == Long_> || op == Long_<=) + Block(keepOnlySideEffects(lhs), BooleanLiteral(op == Long_<=)) + else + foldBinaryOp(if (op == Long_>=) Long_== else Long_!=, lhs, rhs) + + case (LongFromInt(x), LongFromInt(y)) => + foldBinaryOp(intOp, x, y) + case (LongFromInt(x), LongLiteral(y)) => + assert(y > Int.MaxValue || y < Int.MinValue) + val result = + if (y > Int.MaxValue) op == Long_< || op == Long_<= + else op == Long_> || op == Long_>= + Block(keepOnlySideEffects(x), BooleanLiteral(result)) + + /* x + y.toLong > z + * -x on both sides + * requires x + y.toLong not to overflow, and z - x likewise + * y.toLong > z - x + */ + case (BinaryOp(Long_+, LongLiteral(x), y @ LongFromInt(_)), LongLiteral(z)) + if canAddLongs(x, Int.MinValue) && + canAddLongs(x, Int.MaxValue) && + canSubtractLongs(z, x) => + foldBinaryOp(op, y, LongLiteral(z-x)) + + /* x - y.toLong > z + * -x on both sides + * requires x - y.toLong not to overflow, and z - x likewise + * -(y.toLong) > z - x + */ + case (BinaryOp(Long_-, LongLiteral(x), y @ LongFromInt(_)), LongLiteral(z)) + if canSubtractLongs(x, Int.MinValue) && + canSubtractLongs(x, Int.MaxValue) && + canSubtractLongs(z, x) => + if (z-x != Long.MinValue) { + // Since -(y.toLong) does not overflow, we can negate both sides + foldBinaryOp(flippedOp, y, LongLiteral(-(z-x))) + } else { + /* -(y.toLong) > Long.MinValue + * Depending on the operator, this is either always true or + * always false. + */ + val result = (op == Long_>) || (op == Long_>=) + Block(keepOnlySideEffects(y), BooleanLiteral(result)) + } + + /* x.toLong + y.toLong > Int.MaxValue.toLong + * + * This is basically testing whether x+y overflows in positive. + * If x <= 0 or y <= 0, this cannot happen -> false. + * If x > 0 and y > 0, this can be detected with x+y < 0. + * Therefore, we rewrite as: + * + * x > 0 && y > 0 && x+y < 0. + * + * This requires to evaluate x and y once. + */ + case (BinaryOp(Long_+, LongFromInt(x), LongFromInt(y)), + LongLiteral(Int.MaxValue)) => + trampoline { + withNewLocalDefs(List( + Binding("x", None, IntType, false, PreTransTree(x)), + Binding("y", None, IntType, false, PreTransTree(y)))) { + (tempsLocalDefs, cont) => + val List(tempXDef, tempYDef) = tempsLocalDefs + val tempX = tempXDef.newReplacement + val tempY = tempYDef.newReplacement + cont(PreTransTree( + AndThen(AndThen( + BinaryOp(Num_>, tempX, IntLiteral(0)), + BinaryOp(Num_>, tempY, IntLiteral(0))), + BinaryOp(Num_<, BinaryOp(Int_+, tempX, tempY), IntLiteral(0))))) + } (finishTransform(isStat = false)) + } + + case (LongLiteral(_), _) => foldBinaryOp(flippedOp, rhs, lhs) + case _ => default + } + + case Float_+ => + (lhs, rhs) match { + case (FloatLiteral(l), FloatLiteral(r)) => FloatLiteral(l + r) + case (FloatLiteral(0), _) => rhs + case (_, FloatLiteral(_)) => foldBinaryOp(Float_+, rhs, lhs) + + case (FloatLiteral(x), + BinaryOp(innerOp @ (Float_+ | Float_-), FloatLiteral(y), z)) => + foldBinaryOp(innerOp, FloatLiteral(x+y), z) + + case _ => default + } + + case Float_- => + (lhs, rhs) match { + case (_, FloatLiteral(r)) => foldBinaryOp(Float_+, lhs, FloatLiteral(-r)) + + case (FloatLiteral(x), BinaryOp(Float_+, FloatLiteral(y), z)) => + foldBinaryOp(Float_-, FloatLiteral(x-y), z) + case (FloatLiteral(x), BinaryOp(Float_-, FloatLiteral(y), z)) => + foldBinaryOp(Float_+, FloatLiteral(x-y), z) + + case (_, BinaryOp(BinaryOp.Float_-, FloatLiteral(0), x)) => + foldBinaryOp(Float_+, lhs, x) + + case _ => default + } + + case Float_* => + (lhs, rhs) match { + case (FloatLiteral(l), FloatLiteral(r)) => FloatLiteral(l * r) + case (_, FloatLiteral(_)) => foldBinaryOp(Float_*, rhs, lhs) + + case (FloatLiteral(1), _) => rhs + case (FloatLiteral(-1), _) => foldBinaryOp(Float_-, FloatLiteral(0), lhs) + + case _ => default + } + + case Float_/ => + (lhs, rhs) match { + case (FloatLiteral(l), FloatLiteral(r)) => FloatLiteral(l / r) + + case (_, FloatLiteral(1)) => lhs + case (_, FloatLiteral(-1)) => foldBinaryOp(Float_-, FloatLiteral(0), lhs) + + case _ => default + } + + case Float_% => + (lhs, rhs) match { + case (FloatLiteral(l), FloatLiteral(r)) => FloatLiteral(l % r) + case _ => default + } + + case Double_+ => + (lhs, rhs) match { + case (NumberLiteral(l), NumberLiteral(r)) => DoubleLiteral(l + r) + case (NumberLiteral(0), _) => rhs + case (_, NumberLiteral(_)) => foldBinaryOp(Double_+, rhs, lhs) + + case (NumberLiteral(x), + BinaryOp(innerOp @ (Double_+ | Double_-), NumberLiteral(y), z)) => + foldBinaryOp(innerOp, DoubleLiteral(x+y), z) + + case _ => default + } + + case Double_- => + (lhs, rhs) match { + case (_, NumberLiteral(r)) => foldBinaryOp(Double_+, lhs, DoubleLiteral(-r)) + + case (NumberLiteral(x), BinaryOp(Double_+, NumberLiteral(y), z)) => + foldBinaryOp(Double_-, DoubleLiteral(x-y), z) + case (NumberLiteral(x), BinaryOp(Double_-, NumberLiteral(y), z)) => + foldBinaryOp(Double_+, DoubleLiteral(x-y), z) + + case (_, BinaryOp(BinaryOp.Double_-, NumberLiteral(0), x)) => + foldBinaryOp(Double_+, lhs, x) + + case _ => default + } + + case Double_* => + (lhs, rhs) match { + case (NumberLiteral(l), NumberLiteral(r)) => DoubleLiteral(l * r) + case (_, NumberLiteral(_)) => foldBinaryOp(Double_*, rhs, lhs) + + case (NumberLiteral(1), _) => rhs + case (NumberLiteral(-1), _) => foldBinaryOp(Double_-, DoubleLiteral(0), lhs) + + case _ => default + } + + case Double_/ => + (lhs, rhs) match { + case (NumberLiteral(l), NumberLiteral(r)) => DoubleLiteral(l / r) + + case (_, NumberLiteral(1)) => lhs + case (_, NumberLiteral(-1)) => foldBinaryOp(Double_-, DoubleLiteral(0), lhs) + + case _ => default + } + + case Double_% => + (lhs, rhs) match { + case (NumberLiteral(l), NumberLiteral(r)) => DoubleLiteral(l % r) + case _ => default + } + + case Boolean_== | Boolean_!= => + val positive = (op == Boolean_==) + (lhs, rhs) match { + case (BooleanLiteral(l), _) => + if (l == positive) rhs + else foldUnaryOp(UnaryOp.Boolean_!, rhs) + case (_, BooleanLiteral(r)) => + if (r == positive) lhs + else foldUnaryOp(UnaryOp.Boolean_!, lhs) + case _ => + default + } + + case Boolean_| => + (lhs, rhs) match { + case (_, BooleanLiteral(false)) => lhs + case (BooleanLiteral(false), _) => rhs + case _ => default + } + + case Boolean_& => + (lhs, rhs) match { + case (_, BooleanLiteral(true)) => lhs + case (BooleanLiteral(true), _) => rhs + case _ => default + } + + case Num_== | Num_!= => + val positive = (op == Num_==) + (lhs, rhs) match { + case (lhs: Literal, rhs: Literal) => + BooleanLiteral(literal_===(lhs, rhs) == positive) + + case (BinaryOp(Int_+, IntLiteral(x), y), IntLiteral(z)) => + foldBinaryOp(op, y, IntLiteral(z-x)) + case (BinaryOp(Int_-, IntLiteral(x), y), IntLiteral(z)) => + foldBinaryOp(op, y, IntLiteral(x-z)) + + case (_: Literal, _) => foldBinaryOp(op, rhs, lhs) + case _ => default + } + + case Num_< | Num_<= | Num_> | Num_>= => + def flippedOp = (op: @switch) match { + case Num_< => Num_> + case Num_<= => Num_>= + case Num_> => Num_< + case Num_>= => Num_<= + } + + if (lhs.tpe == IntType && rhs.tpe == IntType) { + (lhs, rhs) match { + case (IntLiteral(l), IntLiteral(r)) => + val result = (op: @switch) match { + case Num_< => l < r + case Num_<= => l <= r + case Num_> => l > r + case Num_>= => l >= r + } + BooleanLiteral(result) + + case (_, IntLiteral(Int.MinValue)) => + if (op == Num_< || op == Num_>=) + Block(keepOnlySideEffects(lhs), BooleanLiteral(op == Num_>=)) + else + foldBinaryOp(if (op == Num_<=) Num_== else Num_!=, lhs, rhs) + + case (_, IntLiteral(Int.MaxValue)) => + if (op == Num_> || op == Num_<=) + Block(keepOnlySideEffects(lhs), BooleanLiteral(op == Num_<=)) + else + foldBinaryOp(if (op == Num_>=) Num_== else Num_!=, lhs, rhs) + + case (IntLiteral(_), _) => foldBinaryOp(flippedOp, rhs, lhs) + case _ => default + } + } else { + (lhs, rhs) match { + case (NumberLiteral(l), NumberLiteral(r)) => + val result = (op: @switch) match { + case Num_< => l < r + case Num_<= => l <= r + case Num_> => l > r + case Num_>= => l >= r + } + BooleanLiteral(result) + + case _ => default + } + } + + case _ => + default + } + } + + private def fold3WayComparison(canBeEqual: Boolean, canBeLessThan: Boolean, + canBeGreaterThan: Boolean, lhs: Tree, rhs: Tree)( + implicit pos: Position): Tree = { + import BinaryOp._ + if (canBeEqual) { + if (canBeLessThan) { + if (canBeGreaterThan) + Block(keepOnlySideEffects(lhs), keepOnlySideEffects(rhs), BooleanLiteral(true)) + else + foldBinaryOp(Num_<=, lhs, rhs) + } else { + if (canBeGreaterThan) + foldBinaryOp(Num_>=, lhs, rhs) + else + foldBinaryOp(Num_==, lhs, rhs) + } + } else { + if (canBeLessThan) { + if (canBeGreaterThan) + foldBinaryOp(Num_!=, lhs, rhs) + else + foldBinaryOp(Num_<, lhs, rhs) + } else { + if (canBeGreaterThan) + foldBinaryOp(Num_>, lhs, rhs) + else + Block(keepOnlySideEffects(lhs), keepOnlySideEffects(rhs), BooleanLiteral(false)) + } + } + } + + private def foldUnbox(arg: PreTransform, charCode: Char)( + cont: PreTransCont): TailRec[Tree] = { + (charCode: @switch) match { + case 'Z' if arg.tpe.base == BooleanType => cont(arg) + case 'I' if arg.tpe.base == IntType => cont(arg) + case 'F' if arg.tpe.base == FloatType => cont(arg) + case 'J' if arg.tpe.base == LongType => cont(arg) + case 'D' if arg.tpe.base == DoubleType || + arg.tpe.base == IntType || arg.tpe.base == FloatType => cont(arg) + case _ => + cont(PreTransTree(Unbox(finishTransformExpr(arg), charCode)(arg.pos))) + } + } + + private def foldReferenceEquality(tlhs: PreTransform, trhs: PreTransform, + positive: Boolean = true)(implicit pos: Position): Tree = { + (tlhs, trhs) match { + case (_, PreTransTree(Null(), _)) if !tlhs.tpe.isNullable => + Block( + finishTransformStat(tlhs), + BooleanLiteral(!positive)) + case (PreTransTree(Null(), _), _) if !trhs.tpe.isNullable => + Block( + finishTransformStat(trhs), + BooleanLiteral(!positive)) + case _ => + foldBinaryOp(if (positive) BinaryOp.=== else BinaryOp.!==, + finishTransformExpr(tlhs), finishTransformExpr(trhs)) + } + } + + private def finishTransformCheckNull(preTrans: PreTransform)( + implicit pos: Position): Tree = { + if (preTrans.tpe.isNullable) { + val transformed = finishTransformExpr(preTrans) + CallHelper("checkNonNull", transformed)(transformed.tpe) + } else { + finishTransformExpr(preTrans) + } + } + + def transformIsolatedBody(optTarget: Option[MethodID], + thisType: Type, params: List[ParamDef], resultType: Type, + body: Tree): (List[ParamDef], Tree) = { + val (paramLocalDefs, newParamDefs) = (for { + p @ ParamDef(ident @ Ident(name, originalName), ptpe, mutable) <- params + } yield { + val newName = freshLocalName(name) + val newOriginalName = originalName.orElse(Some(newName)) + val localDef = LocalDef(RefinedType(ptpe), mutable, + ReplaceWithVarRef(newName, newOriginalName, new SimpleState(true), None)) + val newParamDef = ParamDef( + Ident(newName, newOriginalName)(ident.pos), ptpe, mutable)(p.pos) + ((name -> localDef), newParamDef) + }).unzip + + val thisLocalDef = + if (thisType == NoType) None + else { + Some("this" -> LocalDef( + RefinedType(thisType, isExact = false, isNullable = false), + false, ReplaceWithThis())) + } + + val allLocalDefs = thisLocalDef ++: paramLocalDefs + + val scope0 = optTarget.fold(Scope.Empty)( + target => Scope.Empty.inlining((None, target))) + val scope = scope0.withEnv(OptEnv.Empty.withLocalDefs(allLocalDefs)) + val newBody = + transform(body, resultType == NoType)(scope) + + (newParamDefs, newBody) + } + + private def returnable(oldLabelName: String, resultType: Type, + body: Tree, isStat: Boolean, usePreTransform: Boolean)( + cont: PreTransCont)( + implicit scope: Scope, pos: Position): TailRec[Tree] = tailcall { + val newLabel = freshLabelName( + if (oldLabelName.isEmpty) "inlinereturn" else oldLabelName) + + def doMakeTree(newBody: Tree, returnedTypes: List[Type]): Tree = { + val refinedType = + returnedTypes.reduce(constrainedLub(_, _, resultType)) + val returnCount = returnedTypes.size - 1 + + tryOptimizePatternMatch(oldLabelName, refinedType, + returnCount, newBody) getOrElse { + Labeled(Ident(newLabel, None), refinedType, newBody) + } + } + + val info = new LabelInfo(newLabel, acceptRecords = usePreTransform) + withState(info.returnedTypes) { + val bodyScope = scope.withEnv(scope.env.withLabelInfo(oldLabelName, info)) + + if (usePreTransform) { + assert(!isStat, "Cannot use pretransform in statement position") + tryOrRollback { cancelFun => + pretransformExpr(body) { tbody0 => + val returnedTypes0 = info.returnedTypes.value + if (returnedTypes0.isEmpty) { + // no return to that label, we can eliminate it + cont(tbody0) + } else { + val tbody = resolveLocalDef(tbody0) + val (newBody, returnedTypes) = tbody match { + case PreTransRecordTree(bodyTree, origType, _) => + (bodyTree, (bodyTree.tpe, origType) :: returnedTypes0) + case PreTransTree(bodyTree, tpe) => + (bodyTree, (bodyTree.tpe, tpe) :: returnedTypes0) + } + val (actualTypes, origTypes) = returnedTypes.unzip + val refinedOrigType = + origTypes.reduce(constrainedLub(_, _, resultType)) + actualTypes.collectFirst { + case actualType: RecordType => actualType + }.fold[TailRec[Tree]] { + // None of the returned types are records + cont(PreTransTree( + doMakeTree(newBody, actualTypes), refinedOrigType)) + } { recordType => + if (actualTypes.exists(t => t != recordType && t != NothingType)) + cancelFun() + + val resultTree = doMakeTree(newBody, actualTypes) + + if (origTypes.exists(t => t != refinedOrigType && !t.isNothingType)) + cancelFun() + + cont(PreTransRecordTree(resultTree, refinedOrigType, cancelFun)) + } + } + } (bodyScope) + } { () => + returnable(oldLabelName, resultType, body, isStat, + usePreTransform = false)(cont) + } + } else { + val newBody = transform(body, isStat)(bodyScope) + val returnedTypes0 = info.returnedTypes.value.map(_._1) + if (returnedTypes0.isEmpty) { + // no return to that label, we can eliminate it + cont(PreTransTree(newBody, RefinedType(newBody.tpe))) + } else { + val returnedTypes = newBody.tpe :: returnedTypes0 + val tree = doMakeTree(newBody, returnedTypes) + cont(PreTransTree(tree, RefinedType(tree.tpe))) + } + } + } + } + + def tryOptimizePatternMatch(oldLabelName: String, refinedType: Type, + returnCount: Int, newBody: Tree): Option[Tree] = { + if (!oldLabelName.startsWith("matchEnd")) None + else { + newBody match { + case Block(stats) => + @tailrec + def createRevAlts(xs: List[Tree], acc: List[(Tree, Tree)]): List[(Tree, Tree)] = xs match { + case If(cond, body, Skip()) :: xr => + createRevAlts(xr, (cond, body) :: acc) + case remaining => + (EmptyTree, Block(remaining)(remaining.head.pos)) :: acc + } + val revAlts = createRevAlts(stats, Nil) + + if (revAlts.size == returnCount) { + @tailrec + def constructOptimized(revAlts: List[(Tree, Tree)], elsep: Tree): Option[Tree] = { + revAlts match { + case (cond, body) :: revAltsRest => + body match { + case BlockOrAlone(prep, + Return(result, Some(Ident(newLabel, _)))) => + val result1 = + if (refinedType == NoType) keepOnlySideEffects(result) + else result + val prepAndResult = Block(prep :+ result1)(body.pos) + if (cond == EmptyTree) { + assert(elsep == EmptyTree) + constructOptimized(revAltsRest, prepAndResult) + } else { + assert(elsep != EmptyTree) + constructOptimized(revAltsRest, + foldIf(cond, prepAndResult, elsep)(refinedType)(cond.pos)) + } + case _ => + None + } + case Nil => + Some(elsep) + } + } + constructOptimized(revAlts, EmptyTree) + } else None + case _ => + None + } + } + } + + private def withBindings(bindings: List[Binding])( + buildInner: (Scope, PreTransCont) => TailRec[Tree])( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + withNewLocalDefs(bindings) { (localDefs, cont1) => + val newMappings = for { + (binding, localDef) <- bindings zip localDefs + } yield { + binding.name -> localDef + } + buildInner(scope.withEnv(scope.env.withLocalDefs(newMappings)), cont1) + } (cont) + } + + private def withBinding(binding: Binding)( + buildInner: (Scope, PreTransCont) => TailRec[Tree])( + cont: PreTransCont)( + implicit scope: Scope): TailRec[Tree] = { + withNewLocalDef(binding) { (localDef, cont1) => + buildInner(scope.withEnv(scope.env.withLocalDef(binding.name, localDef)), + cont1) + } (cont) + } + + private def withNewLocalDefs(bindings: List[Binding])( + buildInner: (List[LocalDef], PreTransCont) => TailRec[Tree])( + cont: PreTransCont): TailRec[Tree] = { + bindings match { + case first :: rest => + withNewLocalDef(first) { (firstLocalDef, cont1) => + withNewLocalDefs(rest) { (restLocalDefs, cont2) => + buildInner(firstLocalDef :: restLocalDefs, cont2) + } (cont1) + } (cont) + + case Nil => + buildInner(Nil, cont) + } + } + + private def isImmutableType(tpe: Type): Boolean = tpe match { + case RecordType(fields) => + fields.forall(f => !f.mutable && isImmutableType(f.tpe)) + case _ => + true + } + + private def withNewLocalDef(binding: Binding)( + buildInner: (LocalDef, PreTransCont) => TailRec[Tree])( + cont: PreTransCont): TailRec[Tree] = tailcall { + val Binding(name, originalName, declaredType, mutable, value) = binding + implicit val pos = value.pos + + def withDedicatedVar(tpe: RefinedType): TailRec[Tree] = { + val newName = freshLocalName(name) + val newOriginalName = originalName.orElse(Some(name)) + + val used = new SimpleState(false) + withState(used) { + def doBuildInner(localDef: LocalDef)(varDef: => VarDef)( + cont: PreTransCont): TailRec[Tree] = { + buildInner(localDef, { tinner => + if (used.value) { + cont(PreTransBlock(varDef :: Nil, tinner)) + } else { + tinner match { + case PreTransLocalDef(`localDef`) => + cont(value) + case _ if tinner.contains(localDef) => + cont(PreTransBlock(varDef :: Nil, tinner)) + case _ => + val rhsSideEffects = finishTransformStat(value) + rhsSideEffects match { + case Skip() => + cont(tinner) + case _ => + if (rhsSideEffects.tpe == NothingType) + cont(PreTransTree(rhsSideEffects, RefinedType.Nothing)) + else + cont(PreTransBlock(rhsSideEffects :: Nil, tinner)) + } + } + } + }) + } + + resolveLocalDef(value) match { + case PreTransRecordTree(valueTree, valueTpe, cancelFun) => + val recordType = valueTree.tpe.asInstanceOf[RecordType] + if (!isImmutableType(recordType)) + cancelFun() + val localDef = LocalDef(valueTpe, mutable, + ReplaceWithRecordVarRef(newName, newOriginalName, recordType, + used, cancelFun)) + doBuildInner(localDef) { + VarDef(Ident(newName, newOriginalName), recordType, mutable, + valueTree) + } (cont) + + case PreTransTree(valueTree, valueTpe) => + def doDoBuildInner(optValueTree: Option[() => Tree])( + cont1: PreTransCont) = { + val localDef = LocalDef(tpe, mutable, ReplaceWithVarRef( + newName, newOriginalName, used, optValueTree)) + doBuildInner(localDef) { + VarDef(Ident(newName, newOriginalName), tpe.base, mutable, + optValueTree.fold(valueTree)(_())) + } (cont1) + } + if (mutable) { + doDoBuildInner(None)(cont) + } else (valueTree match { + case LongFromInt(arg) => + withNewLocalDef( + Binding("x", None, IntType, false, PreTransTree(arg))) { + (intLocalDef, cont1) => + doDoBuildInner(Some( + () => LongFromInt(intLocalDef.newReplacement)))( + cont1) + } (cont) + + case BinaryOp(op @ (BinaryOp.Long_+ | BinaryOp.Long_-), + LongFromInt(intLhs), LongFromInt(intRhs)) => + withNewLocalDefs(List( + Binding("x", None, IntType, false, PreTransTree(intLhs)), + Binding("y", None, IntType, false, PreTransTree(intRhs)))) { + (intLocalDefs, cont1) => + val List(lhsLocalDef, rhsLocalDef) = intLocalDefs + doDoBuildInner(Some( + () => BinaryOp(op, + LongFromInt(lhsLocalDef.newReplacement), + LongFromInt(rhsLocalDef.newReplacement))))( + cont1) + } (cont) + + case _ => + doDoBuildInner(None)(cont) + }) + } + } + } + + if (value.tpe.isNothingType) { + cont(value) + } else if (mutable) { + withDedicatedVar(RefinedType(declaredType)) + } else { + val refinedType = value.tpe + value match { + case PreTransBlock(stats, result) => + withNewLocalDef(binding.copy(value = result))(buildInner) { tresult => + cont(PreTransBlock(stats, tresult)) + } + + case PreTransLocalDef(localDef) if !localDef.mutable => + buildInner(localDef, cont) + + case PreTransTree(literal: Literal, _) => + buildInner(LocalDef(refinedType, false, + ReplaceWithConstant(literal)), cont) + + case PreTransTree(VarRef(Ident(refName, refOriginalName), false), _) => + buildInner(LocalDef(refinedType, false, + ReplaceWithVarRef(refName, refOriginalName, + new SimpleState(true), None)), cont) + + case _ => + withDedicatedVar(refinedType) + } + } + } + + /** Finds a type as precise as possible which is a supertype of lhs and rhs + * but still a subtype of upperBound. + * Requires that lhs and rhs be subtypes of upperBound, obviously. + */ + private def constrainedLub(lhs: RefinedType, rhs: RefinedType, + upperBound: Type): RefinedType = { + if (upperBound == NoType) RefinedType(upperBound) + else if (lhs == rhs) lhs + else if (lhs.isNothingType) rhs + else if (rhs.isNothingType) lhs + else { + RefinedType(constrainedLub(lhs.base, rhs.base, upperBound), + false, lhs.isNullable || rhs.isNullable) + } + } + + /** Finds a type as precise as possible which is a supertype of lhs and rhs + * but still a subtype of upperBound. + * Requires that lhs and rhs be subtypes of upperBound, obviously. + */ + private def constrainedLub(lhs: Type, rhs: Type, upperBound: Type): Type = { + // TODO Improve this + if (upperBound == NoType) upperBound + else if (lhs == rhs) lhs + else if (lhs == NothingType) rhs + else if (rhs == NothingType) lhs + else upperBound + } + + /** Trampolines a pretransform */ + private def trampoline(tailrec: => TailRec[Tree]): Tree = { + curTrampolineId += 1 + + val myTrampolineId = curTrampolineId + + try { + var rec = () => tailrec + + while (true) { + try { + return rec().result + } catch { + case e: RollbackException if e.trampolineId == myTrampolineId => + rollbacksCount += 1 + if (rollbacksCount > MaxRollbacksPerMethod) + throw new TooManyRollbacksException + + usedLocalNames.clear() + usedLocalNames ++= e.savedUsedLocalNames + usedLabelNames.clear() + usedLabelNames ++= e.savedUsedLabelNames + for ((state, backup) <- statesInUse zip e.savedStates) + state.asInstanceOf[State[Any]].restore(backup) + + rec = e.cont + } + } + + sys.error("Reached end of infinite loop") + } finally { + curTrampolineId -= 1 + } + } +} + +private[optimizer] object OptimizerCore { + + private final val MaxRollbacksPerMethod = 256 + + private final class TooManyRollbacksException + extends scala.util.control.ControlThrowable + + private val AnonFunctionClassPrefix = "sjsr_AnonFunction" + + private type CancelFun = () => Nothing + private type PreTransCont = PreTransform => TailRec[Tree] + + private case class RefinedType private (base: Type, isExact: Boolean, + isNullable: Boolean)( + val allocationSite: Option[AllocationSite], dummy: Int = 0) { + + def isNothingType: Boolean = base == NothingType + } + + private object RefinedType { + def apply(base: Type, isExact: Boolean, isNullable: Boolean, + allocationSite: Option[AllocationSite]): RefinedType = + new RefinedType(base, isExact, isNullable)(allocationSite) + + def apply(base: Type, isExact: Boolean, isNullable: Boolean): RefinedType = + RefinedType(base, isExact, isNullable, None) + + def apply(tpe: Type): RefinedType = tpe match { + case BooleanType | IntType | FloatType | DoubleType | StringType | + UndefType | NothingType | _:RecordType | NoType => + RefinedType(tpe, isExact = true, isNullable = false) + case NullType => + RefinedType(tpe, isExact = true, isNullable = true) + case _ => + RefinedType(tpe, isExact = false, isNullable = true) + } + + val NoRefinedType = RefinedType(NoType) + val Nothing = RefinedType(NothingType) + } + + private class AllocationSite(private val node: Tree) { + override def equals(that: Any): Boolean = that match { + case that: AllocationSite => this.node eq that.node + case _ => false + } + + override def hashCode(): Int = + System.identityHashCode(node) + + override def toString(): String = + s"AllocationSite($node)" + } + + private case class LocalDef( + tpe: RefinedType, + mutable: Boolean, + replacement: LocalDefReplacement) { + + def newReplacement(implicit pos: Position): Tree = replacement match { + case ReplaceWithVarRef(name, originalName, used, _) => + used.value = true + VarRef(Ident(name, originalName), mutable)(tpe.base) + + case ReplaceWithRecordVarRef(_, _, _, _, cancelFun) => + cancelFun() + + case ReplaceWithThis() => + This()(tpe.base) + + case ReplaceWithConstant(value) => + value + + case TentativeClosureReplacement(_, _, _, _, _, cancelFun) => + cancelFun() + + case InlineClassBeingConstructedReplacement(_, cancelFun) => + cancelFun() + + case InlineClassInstanceReplacement(_, _, cancelFun) => + cancelFun() + } + + def contains(that: LocalDef): Boolean = { + (this eq that) || (replacement match { + case TentativeClosureReplacement(_, _, _, captureLocalDefs, _, _) => + captureLocalDefs.exists(_.contains(that)) + case InlineClassInstanceReplacement(_, fieldLocalDefs, _) => + fieldLocalDefs.valuesIterator.exists(_.contains(that)) + case _ => + false + }) + } + } + + private sealed abstract class LocalDefReplacement + + private final case class ReplaceWithVarRef(name: String, + originalName: Option[String], + used: SimpleState[Boolean], + longOpTree: Option[() => Tree]) extends LocalDefReplacement + + private final case class ReplaceWithRecordVarRef(name: String, + originalName: Option[String], + recordType: RecordType, + used: SimpleState[Boolean], + cancelFun: CancelFun) extends LocalDefReplacement + + private final case class ReplaceWithThis() extends LocalDefReplacement + + private final case class ReplaceWithConstant( + value: Tree) extends LocalDefReplacement + + private final case class TentativeClosureReplacement( + captureParams: List[ParamDef], params: List[ParamDef], body: Tree, + captureValues: List[LocalDef], + alreadyUsed: SimpleState[Boolean], + cancelFun: CancelFun) extends LocalDefReplacement + + private final case class InlineClassBeingConstructedReplacement( + fieldLocalDefs: Map[String, LocalDef], + cancelFun: CancelFun) extends LocalDefReplacement + + private final case class InlineClassInstanceReplacement( + recordType: RecordType, + fieldLocalDefs: Map[String, LocalDef], + cancelFun: CancelFun) extends LocalDefReplacement + + private final class LabelInfo( + val newName: String, + val acceptRecords: Boolean, + /** (actualType, originalType), actualType can be a RecordType. */ + val returnedTypes: SimpleState[List[(Type, RefinedType)]] = new SimpleState(Nil)) + + private class OptEnv( + val localDefs: Map[String, LocalDef], + val labelInfos: Map[String, LabelInfo]) { + + def withLocalDef(oldName: String, rep: LocalDef): OptEnv = + new OptEnv(localDefs + (oldName -> rep), labelInfos) + + def withLocalDefs(reps: List[(String, LocalDef)]): OptEnv = + new OptEnv(localDefs ++ reps, labelInfos) + + def withLabelInfo(oldName: String, info: LabelInfo): OptEnv = + new OptEnv(localDefs, labelInfos + (oldName -> info)) + + def withinFunction(paramLocalDefs: List[(String, LocalDef)]): OptEnv = + new OptEnv(localDefs ++ paramLocalDefs, Map.empty) + + override def toString(): String = { + "localDefs:"+localDefs.mkString("\n ", "\n ", "\n") + + "labelInfos:"+labelInfos.mkString("\n ", "\n ", "") + } + } + + private object OptEnv { + val Empty: OptEnv = new OptEnv(Map.empty, Map.empty) + } + + private class Scope(val env: OptEnv, + val implsBeingInlined: Set[(Option[AllocationSite], AbstractMethodID)]) { + def withEnv(env: OptEnv): Scope = + new Scope(env, implsBeingInlined) + + def inlining(impl: (Option[AllocationSite], AbstractMethodID)): Scope = { + assert(!implsBeingInlined(impl), s"Circular inlining of $impl") + new Scope(env, implsBeingInlined + impl) + } + } + + private object Scope { + val Empty: Scope = new Scope(OptEnv.Empty, Set.empty) + } + + /** The result of pretransformExpr(). + * It has a `tpe` as precisely refined as if a full transformExpr() had + * been performed. + * It is also not dependent on the environment anymore. In some sense, it + * has "captured" its environment at definition site. + */ + private sealed abstract class PreTransform { + def pos: Position + val tpe: RefinedType + + def contains(localDef: LocalDef): Boolean = this match { + case PreTransBlock(_, result) => + result.contains(localDef) + case PreTransLocalDef(thisLocalDef) => + thisLocalDef.contains(localDef) + case _ => + false + } + } + + private final class PreTransBlock private (val stats: List[Tree], + val result: PreTransLocalDef) extends PreTransform { + def pos = result.pos + val tpe = result.tpe + + assert(stats.nonEmpty) + + override def toString(): String = + s"PreTransBlock($stats,$result)" + } + + private object PreTransBlock { + def apply(stats: List[Tree], result: PreTransform): PreTransform = { + if (stats.isEmpty) result + else { + result match { + case PreTransBlock(innerStats, innerResult) => + new PreTransBlock(stats ++ innerStats, innerResult) + case result: PreTransLocalDef => + new PreTransBlock(stats, result) + case PreTransRecordTree(tree, tpe, cancelFun) => + PreTransRecordTree(Block(stats :+ tree)(tree.pos), tpe, cancelFun) + case PreTransTree(tree, tpe) => + PreTransTree(Block(stats :+ tree)(tree.pos), tpe) + } + } + } + + def unapply(preTrans: PreTransBlock): Some[(List[Tree], PreTransLocalDef)] = + Some(preTrans.stats, preTrans.result) + } + + private sealed abstract class PreTransNoBlock extends PreTransform + + private final case class PreTransLocalDef(localDef: LocalDef)( + implicit val pos: Position) extends PreTransNoBlock { + val tpe: RefinedType = localDef.tpe + } + + private sealed abstract class PreTransGenTree extends PreTransNoBlock + + private final case class PreTransRecordTree(tree: Tree, + tpe: RefinedType, cancelFun: CancelFun) extends PreTransGenTree { + def pos = tree.pos + + assert(tree.tpe.isInstanceOf[RecordType], + s"Cannot create a PreTransRecordTree with non-record type ${tree.tpe}") + } + + private final case class PreTransTree(tree: Tree, + tpe: RefinedType) extends PreTransGenTree { + def pos: Position = tree.pos + + assert(!tree.tpe.isInstanceOf[RecordType], + s"Cannot create a Tree with record type ${tree.tpe}") + } + + private object PreTransTree { + def apply(tree: Tree): PreTransTree = + PreTransTree(tree, RefinedType(tree.tpe)) + } + + private final case class Binding(name: String, originalName: Option[String], + declaredType: Type, mutable: Boolean, value: PreTransform) + + private object NumberLiteral { + def unapply(tree: Literal): Option[Double] = tree match { + case DoubleLiteral(v) => Some(v) + case IntLiteral(v) => Some(v.toDouble) + case FloatLiteral(v) => Some(v.toDouble) + case _ => None + } + } + + private object LongFromInt { + def apply(x: Tree)(implicit pos: Position): Tree = x match { + case IntLiteral(v) => LongLiteral(v) + case _ => UnaryOp(UnaryOp.IntToLong, x) + } + + def unapply(tree: Tree): Option[Tree] = tree match { + case LongLiteral(v) if v.toInt == v => Some(IntLiteral(v.toInt)(tree.pos)) + case UnaryOp(UnaryOp.IntToLong, x) => Some(x) + case _ => None + } + } + + private object AndThen { + def apply(lhs: Tree, rhs: Tree)(implicit pos: Position): Tree = + If(lhs, rhs, BooleanLiteral(false))(BooleanType) + } + + /** Tests whether `x + y` is valid without falling out of range. */ + private def canAddLongs(x: Long, y: Long): Boolean = + if (y >= 0) x+y >= x + else x+y < x + + /** Tests whether `x - y` is valid without falling out of range. */ + private def canSubtractLongs(x: Long, y: Long): Boolean = + if (y >= 0) x-y <= x + else x-y > x + + /** Tests whether `-x` is valid without falling out of range. */ + private def canNegateLong(x: Long): Boolean = + x != Long.MinValue + + private object Intrinsics { + final val ArrayCopy = 1 + final val IdentityHashCode = ArrayCopy + 1 + + final val PropertiesOf = IdentityHashCode + 1 + + final val LongToString = PropertiesOf + 1 + final val LongCompare = LongToString + 1 + final val LongBitCount = LongCompare + 1 + final val LongSignum = LongBitCount + 1 + final val LongLeading0s = LongSignum + 1 + final val LongTrailing0s = LongLeading0s + 1 + final val LongToBinStr = LongTrailing0s + 1 + final val LongToHexStr = LongToBinStr + 1 + final val LongToOctalStr = LongToHexStr + 1 + + final val ByteArrayToInt8Array = LongToOctalStr + 1 + final val ShortArrayToInt16Array = ByteArrayToInt8Array + 1 + final val CharArrayToUint16Array = ShortArrayToInt16Array + 1 + final val IntArrayToInt32Array = CharArrayToUint16Array + 1 + final val FloatArrayToFloat32Array = IntArrayToInt32Array + 1 + final val DoubleArrayToFloat64Array = FloatArrayToFloat32Array + 1 + + final val Int8ArrayToByteArray = DoubleArrayToFloat64Array + 1 + final val Int16ArrayToShortArray = Int8ArrayToByteArray + 1 + final val Uint16ArrayToCharArray = Int16ArrayToShortArray + 1 + final val Int32ArrayToIntArray = Uint16ArrayToCharArray + 1 + final val Float32ArrayToFloatArray = Int32ArrayToIntArray + 1 + final val Float64ArrayToDoubleArray = Float32ArrayToFloatArray + 1 + + val intrinsics: Map[String, Int] = Map( + "jl_System$.arraycopy__O__I__O__I__I__V" -> ArrayCopy, + "jl_System$.identityHashCode__O__I" -> IdentityHashCode, + + "sjsr_package$.propertiesOf__sjs_js_Any__sjs_js_Array" -> PropertiesOf, + + "jl_Long$.toString__J__T" -> LongToString, + "jl_Long$.compare__J__J__I" -> LongCompare, + "jl_Long$.bitCount__J__I" -> LongBitCount, + "jl_Long$.signum__J__J" -> LongSignum, + "jl_Long$.numberOfLeadingZeros__J__I" -> LongLeading0s, + "jl_Long$.numberOfTrailingZeros__J__I" -> LongTrailing0s, + "jl_long$.toBinaryString__J__T" -> LongToBinStr, + "jl_Long$.toHexString__J__T" -> LongToHexStr, + "jl_Long$.toOctalString__J__T" -> LongToOctalStr, + + "sjs_js_typedarray_package$.byteArray2Int8Array__AB__sjs_js_typedarray_Int8Array" -> ByteArrayToInt8Array, + "sjs_js_typedarray_package$.shortArray2Int16Array__AS__sjs_js_typedarray_Int16Array" -> ShortArrayToInt16Array, + "sjs_js_typedarray_package$.charArray2Uint16Array__AC__sjs_js_typedarray_Uint16Array" -> CharArrayToUint16Array, + "sjs_js_typedarray_package$.intArray2Int32Array__AI__sjs_js_typedarray_Int32Array" -> IntArrayToInt32Array, + "sjs_js_typedarray_package$.floatArray2Float32Array__AF__sjs_js_typedarray_Float32Array" -> FloatArrayToFloat32Array, + "sjs_js_typedarray_package$.doubleArray2Float64Array__AD__sjs_js_typedarray_Float64Array" -> DoubleArrayToFloat64Array, + + "sjs_js_typedarray_package$.int8Array2ByteArray__sjs_js_typedarray_Int8Array__AB" -> Int8ArrayToByteArray, + "sjs_js_typedarray_package$.int16Array2ShortArray__sjs_js_typedarray_Int16Array__AS" -> Int16ArrayToShortArray, + "sjs_js_typedarray_package$.uint16Array2CharArray__sjs_js_typedarray_Uint16Array__AC" -> Uint16ArrayToCharArray, + "sjs_js_typedarray_package$.int32Array2IntArray__sjs_js_typedarray_Int32Array__AI" -> Int32ArrayToIntArray, + "sjs_js_typedarray_package$.float32Array2FloatArray__sjs_js_typedarray_Float32Array__AF" -> Float32ArrayToFloatArray, + "sjs_js_typedarray_package$.float64Array2DoubleArray__sjs_js_typedarray_Float64Array__AD" -> Float64ArrayToDoubleArray + ).withDefaultValue(-1) + } + + private def getIntrinsicCode(target: AbstractMethodID): Int = + Intrinsics.intrinsics(target.toString) + + private trait State[A] { + def makeBackup(): A + def restore(backup: A): Unit + } + + private class SimpleState[A](var value: A) extends State[A] { + def makeBackup(): A = value + def restore(backup: A): Unit = value = backup + } + + trait AbstractMethodID { + def inlineable: Boolean + def isTraitImplForwarder: Boolean + } + + /** Parts of [[GenIncOptimizer#MethodImpl]] with decisions about optimizations. */ + abstract class MethodImpl { + def encodedName: String + def optimizerHints: OptimizerHints + def originalDef: MethodDef + def thisType: Type + + var inlineable: Boolean = false + var isTraitImplForwarder: Boolean = false + + protected def updateInlineable(): Unit = { + val MethodDef(Ident(methodName, _), params, _, body) = originalDef + + isTraitImplForwarder = body match { + // Shape of forwarders to trait impls + case TraitImplApply(impl, method, args) => + ((args.size == params.size + 1) && + (args.head.isInstanceOf[This]) && + (args.tail.zip(params).forall { + case (VarRef(Ident(aname, _), _), + ParamDef(Ident(pname, _), _, _)) => aname == pname + case _ => false + })) + + case _ => false + } + + inlineable = optimizerHints.hasInlineAnnot || isTraitImplForwarder || { + val MethodDef(_, params, _, body) = originalDef + body match { + case _:Skip | _:This | _:Literal => true + + // Shape of accessors + case Select(This(), _, _) if params.isEmpty => true + case Assign(Select(This(), _, _), VarRef(_, _)) + if params.size == 1 => true + + // Shape of trivial call-super constructors + case Block(stats) + if params.isEmpty && isConstructorName(encodedName) && + stats.forall(isTrivialConstructorStat) => true + + // Simple method + case SimpleMethodBody() => true + + case _ => false + } + } + } + } + + private def isTrivialConstructorStat(stat: Tree): Boolean = stat match { + case This() => + true + case StaticApply(This(), _, _, Nil) => + true + case TraitImplApply(_, Ident(methodName, _), This() :: Nil) => + methodName.contains("__$init$__") + case _ => + false + } + + private object SimpleMethodBody { + @tailrec + def unapply(body: Tree): Boolean = body match { + case New(_, _, args) => areSimpleArgs(args) + case Apply(receiver, _, args) => areSimpleArgs(receiver :: args) + case StaticApply(receiver, _, _, args) => areSimpleArgs(receiver :: args) + case TraitImplApply(_, _, args) => areSimpleArgs(args) + case Select(qual, _, _) => isSimpleArg(qual) + case IsInstanceOf(inner, _) => isSimpleArg(inner) + + case Block(List(inner, Undefined())) => + unapply(inner) + + case Unbox(inner, _) => unapply(inner) + case AsInstanceOf(inner, _) => unapply(inner) + + case _ => isSimpleArg(body) + } + + private def areSimpleArgs(args: List[Tree]): Boolean = + args.forall(isSimpleArg) + + @tailrec + private def isSimpleArg(arg: Tree): Boolean = arg match { + case New(_, _, Nil) => true + case Apply(receiver, _, Nil) => isTrivialArg(receiver) + case StaticApply(receiver, _, _, Nil) => isTrivialArg(receiver) + case TraitImplApply(_, _, Nil) => true + + case ArrayLength(array) => isTrivialArg(array) + case ArraySelect(array, index) => isTrivialArg(array) && isTrivialArg(index) + + case Unbox(inner, _) => isSimpleArg(inner) + case AsInstanceOf(inner, _) => isSimpleArg(inner) + + case _ => + isTrivialArg(arg) + } + + private def isTrivialArg(arg: Tree): Boolean = arg match { + case _:VarRef | _:This | _:Literal | _:LoadModule => + true + case _ => + false + } + } + + private object BlockOrAlone { + def unapply(tree: Tree): Some[(List[Tree], Tree)] = Some(tree match { + case Block(init :+ last) => (init, last) + case _ => (Nil, tree) + }) + } + + /** Recreates precise [[Infos.MethodInfo]] from the optimized [[MethodDef]]. */ + private def recreateInfo(methodDef: MethodDef): Infos.MethodInfo = { + new RecreateInfoTraverser().recreateInfo(methodDef) + } + + private final class RecreateInfoTraverser extends Traversers.Traverser { + import RecreateInfoTraverser._ + + private val calledMethods = mutable.Map.empty[String, mutable.Set[String]] + private val calledMethodsStatic = mutable.Map.empty[String, mutable.Set[String]] + private val instantiatedClasses = mutable.Set.empty[String] + private val accessedModules = mutable.Set.empty[String] + private val accessedClassData = mutable.Set.empty[String] + + def recreateInfo(methodDef: MethodDef): Infos.MethodInfo = { + traverse(methodDef.body) + Infos.MethodInfo( + encodedName = methodDef.name.name, + calledMethods = calledMethods.toMap.mapValues(_.toList), + calledMethodsStatic = calledMethodsStatic.toMap.mapValues(_.toList), + instantiatedClasses = instantiatedClasses.toList, + accessedModules = accessedModules.toList, + accessedClassData = accessedClassData.toList) + } + + private def addCalledMethod(container: String, methodName: String): Unit = + calledMethods.getOrElseUpdate(container, mutable.Set.empty) += methodName + + private def addCalledMethodStatic(container: String, methodName: String): Unit = + calledMethodsStatic.getOrElseUpdate(container, mutable.Set.empty) += methodName + + private def refTypeToClassData(tpe: ReferenceType): String = tpe match { + case ClassType(cls) => cls + case ArrayType(base, _) => base + } + + def addAccessedClassData(encodedName: String): Unit = { + if (!AlwaysPresentClassData.contains(encodedName)) + accessedClassData += encodedName + } + + def addAccessedClassData(tpe: ReferenceType): Unit = + addAccessedClassData(refTypeToClassData(tpe)) + + override def traverse(tree: Tree): Unit = { + tree match { + case New(ClassType(cls), ctor, _) => + instantiatedClasses += cls + addCalledMethodStatic(cls, ctor.name) + + case Apply(receiver, method, _) => + receiver.tpe match { + case ClassType(cls) if !Definitions.HijackedClasses.contains(cls) => + addCalledMethod(cls, method.name) + case AnyType => + addCalledMethod(Definitions.ObjectClass, method.name) + case ArrayType(_, _) if method.name != "clone__O" => + /* clone__O is overridden in the pseudo Array classes and is + * always kept anyway, because it is in scalajsenv.js. + * Other methods delegate to Object, which we can model with + * a static call to Object.method. + */ + addCalledMethodStatic(Definitions.ObjectClass, method.name) + case _ => + // Nothing to do + } + + case StaticApply(_, ClassType(cls), method, _) => + addCalledMethodStatic(cls, method.name) + case TraitImplApply(ClassType(impl), method, _) => + addCalledMethodStatic(impl, method.name) + + case LoadModule(ClassType(cls)) => + accessedModules += cls.stripSuffix("$") + + case NewArray(tpe, _) => + addAccessedClassData(tpe) + case ArrayValue(tpe, _) => + addAccessedClassData(tpe) + case IsInstanceOf(_, cls) => + addAccessedClassData(cls) + case AsInstanceOf(_, cls) => + addAccessedClassData(cls) + case ClassOf(cls) => + addAccessedClassData(cls) + + case _ => + } + super.traverse(tree) + } + } + + private object RecreateInfoTraverser { + /** Class data that are never eliminated by dce, so we don't need to + * record them. + */ + private val AlwaysPresentClassData = { + import Definitions._ + Set("V", "Z", "C", "B", "S", "I", "J", "F", "D", + ObjectClass, StringClass) + } + } + + private def exceptionMsg(myself: AbstractMethodID, + attemptedInlining: List[AbstractMethodID]) = { + val buf = new StringBuilder() + + buf.append("The Scala.js optimizer crashed while optimizing " + myself) + + buf.append("\nMethods attempted to inline:\n") + + for (m <- attemptedInlining) { + buf.append("* ") + buf.append(m) + buf.append('\n') + } + + buf.toString + } + + private class RollbackException(val trampolineId: Int, + val savedUsedLocalNames: Set[String], + val savedUsedLabelNames: Set[String], + val savedStates: List[Any], + val cont: () => TailRec[Tree]) extends ControlThrowable + + class OptimizeException(val myself: AbstractMethodID, + val attemptedInlining: List[AbstractMethodID], cause: Throwable + ) extends Exception(exceptionMsg(myself, attemptedInlining), cause) + +} diff --git a/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/ScalaJSOptimizer.scala b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/ScalaJSOptimizer.scala new file mode 100644 index 0000000..646484b --- /dev/null +++ b/examples/scala-js/tools/shared/src/main/scala/scala/scalajs/tools/optimizer/ScalaJSOptimizer.scala @@ -0,0 +1,552 @@ +/* __ *\ +** ________ ___ / / ___ __ ____ Scala.js tools ** +** / __/ __// _ | / / / _ | __ / // __/ (c) 2013-2014, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ |/_// /_\ \ http://scala-js.org/ ** +** /____/\___/_/ |_/____/_/ | |__/ /____/ ** +** |/____/ ** +\* */ + + +package scala.scalajs.tools.optimizer + +import scala.annotation.{switch, tailrec} + +import scala.collection.mutable +import scala.collection.immutable.{Seq, Traversable} + +import java.net.URI + +import scala.scalajs.ir +import ir.Infos +import ir.ClassKind + +import scala.scalajs.tools.logging._ +import scala.scalajs.tools.io._ +import scala.scalajs.tools.classpath._ +import scala.scalajs.tools.sourcemap._ +import scala.scalajs.tools.corelib._ + +import scala.scalajs.tools.sem.Semantics + +import scala.scalajs.tools.javascript +import javascript.{Trees => js} + +/** Scala.js optimizer: does type-aware global dce. */ +class ScalaJSOptimizer( + semantics: Semantics, + optimizerFactory: (Semantics) => GenIncOptimizer) { + import ScalaJSOptimizer._ + + private val classEmitter = new javascript.ScalaJSClassEmitter(semantics) + + private[this] var persistentState: PersistentState = new PersistentState + private[this] var optimizer: GenIncOptimizer = optimizerFactory(semantics) + + def this(semantics: Semantics) = this(semantics, new IncOptimizer(_)) + + /** Applies Scala.js-specific optimizations to a CompleteIRClasspath. + * See [[ScalaJSOptimizer.Inputs]] for details about the required and + * optional inputs. + * See [[ScalaJSOptimizer.OutputConfig]] for details about the configuration + * for the output of this method. + * Returns a [[CompleteCIClasspath]] containing the result of the + * optimizations. + * + * analyzes, dead code eliminates and concatenates IR content + * - Maintains/establishes order + * - No IR in result + * - CoreJSLibs in result (since they are implicitly in the CompleteIRCP) + */ + def optimizeCP(inputs: Inputs[IRClasspath], outCfg: OutputConfig, + logger: Logger): LinkedClasspath = { + + val cp = inputs.input + + CacheUtils.cached(cp.version, outCfg.output, outCfg.cache) { + logger.info(s"Fast optimizing ${outCfg.output.path}") + optimizeIR(inputs.copy(input = inputs.input.scalaJSIR), outCfg, logger) + } + + new LinkedClasspath(cp.jsLibs, outCfg.output, cp.requiresDOM, cp.version) + } + + def optimizeIR(inputs: Inputs[Traversable[VirtualScalaJSIRFile]], + outCfg: OutputConfig, logger: Logger): Unit = { + + val builder = { + import outCfg._ + if (wantSourceMap) + new JSFileBuilderWithSourceMap(output.name, + output.contentWriter, + output.sourceMapWriter, + relativizeSourceMapBase) + else + new JSFileBuilder(output.name, output.contentWriter) + } + + builder.addLine("'use strict';") + CoreJSLibs.libs(semantics).foreach(builder.addFile _) + + optimizeIR(inputs, outCfg, builder, logger) + + builder.complete() + builder.closeWriters() + } + + def optimizeIR(inputs: Inputs[Traversable[VirtualScalaJSIRFile]], + outCfg: OptimizerConfig, builder: JSTreeBuilder, logger: Logger): Unit = { + + /* Handle tree equivalence: If we handled source maps so far, positions are + still up-to-date. Otherwise we need to flush the state if proper + positions are requested now. + */ + if (outCfg.wantSourceMap && !persistentState.wasWithSourceMap) + clean() + + persistentState.wasWithSourceMap = outCfg.wantSourceMap + + persistentState.startRun() + try { + import inputs._ + val allData = + GenIncOptimizer.logTime(logger, "Read info") { + readAllData(inputs.input, logger) + } + val (useOptimizer, refinedAnalyzer) = GenIncOptimizer.logTime( + logger, "Optimizations part") { + val analyzer = + GenIncOptimizer.logTime(logger, "Compute reachability") { + val analyzer = new Analyzer(logger, semantics, allData, + globalWarnEnabled = true, + isBeforeOptimizer = !outCfg.disableOptimizer) + analyzer.computeReachability(manuallyReachable, noWarnMissing) + analyzer + } + if (outCfg.checkIR) { + GenIncOptimizer.logTime(logger, "Check IR") { + if (analyzer.allAvailable) + checkIR(analyzer, logger) + else if (inputs.noWarnMissing.isEmpty) + sys.error("Could not check IR because there where linking errors.") + } + } + def getClassTreeIfChanged(encodedName: String, + lastVersion: Option[String]): Option[(ir.Trees.ClassDef, Option[String])] = { + val persistentFile = persistentState.encodedNameToPersistentFile(encodedName) + persistentFile.treeIfChanged(lastVersion) + } + + val useOptimizer = analyzer.allAvailable && !outCfg.disableOptimizer + + if (outCfg.batchMode) + optimizer = optimizerFactory(semantics) + + val refinedAnalyzer = if (useOptimizer) { + GenIncOptimizer.logTime(logger, "Inliner") { + optimizer.update(analyzer, getClassTreeIfChanged, + outCfg.wantSourceMap, logger) + } + GenIncOptimizer.logTime(logger, "Refined reachability analysis") { + val refinedData = computeRefinedData(allData, optimizer) + val refinedAnalyzer = new Analyzer(logger, semantics, refinedData, + globalWarnEnabled = false, + isBeforeOptimizer = false) + refinedAnalyzer.computeReachability(manuallyReachable, noWarnMissing) + refinedAnalyzer + } + } else { + if (inputs.noWarnMissing.isEmpty && !outCfg.disableOptimizer) + logger.warn("Not running the inliner because there where linking errors.") + analyzer + } + (useOptimizer, refinedAnalyzer) + } + GenIncOptimizer.logTime(logger, "Write DCE'ed output") { + buildDCEedOutput(builder, refinedAnalyzer, useOptimizer) + } + } finally { + persistentState.endRun(outCfg.unCache) + logger.debug( + s"Inc. opt stats: reused: ${persistentState.statsReused} -- "+ + s"invalidated: ${persistentState.statsInvalidated} -- "+ + s"trees read: ${persistentState.statsTreesRead}") + } + } + + /** Resets all persistent state of this optimizer */ + def clean(): Unit = { + persistentState = new PersistentState + optimizer = optimizerFactory(semantics) + } + + private def readAllData(ir: Traversable[VirtualScalaJSIRFile], + logger: Logger): scala.collection.Seq[Infos.ClassInfo] = { + ir.map(persistentState.getPersistentIRFile(_).info).toSeq + } + + private def checkIR(analyzer: Analyzer, logger: Logger): Unit = { + val allClassDefs = for { + classInfo <- analyzer.classInfos.values + persistentIRFile <- persistentState.encodedNameToPersistentFile.get( + classInfo.encodedName) + } yield persistentIRFile.tree + val checker = new IRChecker(analyzer, allClassDefs.toSeq, logger) + if (!checker.check()) + sys.error(s"There were ${checker.errorCount} IR checking errors.") + } + + private def computeRefinedData( + allData: scala.collection.Seq[Infos.ClassInfo], + optimizer: GenIncOptimizer): scala.collection.Seq[Infos.ClassInfo] = { + + def refineMethodInfo(container: optimizer.MethodContainer, + methodInfo: Infos.MethodInfo): Infos.MethodInfo = { + container.methods.get(methodInfo.encodedName).fold(methodInfo) { + methodImpl => methodImpl.preciseInfo + } + } + + def refineMethodInfos(container: optimizer.MethodContainer, + methodInfos: List[Infos.MethodInfo]): List[Infos.MethodInfo] = { + methodInfos.map(m => refineMethodInfo(container, m)) + } + + def refineClassInfo(container: optimizer.MethodContainer, + info: Infos.ClassInfo): Infos.ClassInfo = { + val refinedMethods = refineMethodInfos(container, info.methods) + Infos.ClassInfo(info.name, info.encodedName, info.isExported, + info.ancestorCount, info.kind, info.superClass, info.ancestors, + Infos.OptimizerHints.empty, refinedMethods) + } + + for { + info <- allData + } yield { + info.kind match { + case ClassKind.Class | ClassKind.ModuleClass => + optimizer.getClass(info.encodedName).fold(info) { + cls => refineClassInfo(cls, info) + } + + case ClassKind.TraitImpl => + optimizer.getTraitImpl(info.encodedName).fold(info) { + impl => refineClassInfo(impl, info) + } + + case _ => + info + } + } + } + + private def buildDCEedOutput(builder: JSTreeBuilder, + analyzer: Analyzer, useInliner: Boolean): Unit = { + + def compareClassInfo(lhs: analyzer.ClassInfo, rhs: analyzer.ClassInfo) = { + if (lhs.ancestorCount != rhs.ancestorCount) lhs.ancestorCount < rhs.ancestorCount + else lhs.encodedName.compareTo(rhs.encodedName) < 0 + } + + def addPersistentFile(classInfo: analyzer.ClassInfo, + persistentFile: PersistentIRFile) = { + import ir.Trees._ + import javascript.JSDesugaring.{desugarJavaScript => desugar} + + val d = persistentFile.desugared + lazy val classDef = { + persistentState.statsTreesRead += 1 + persistentFile.tree + } + + def addTree(tree: js.Tree): Unit = + builder.addJSTree(tree) + + def addReachableMethods(emitFun: (String, MethodDef) => js.Tree): Unit = { + /* This is a bit convoluted because we have to: + * * avoid to use classDef at all if we already know all the needed methods + * * if any new method is needed, better to go through the defs once + */ + val methodNames = d.methodNames.getOrElseUpdate( + classDef.defs collect { + case MethodDef(Ident(encodedName, _), _, _, _) => encodedName + }) + val reachableMethods = methodNames.filter( + name => classInfo.methodInfos(name).isReachable) + if (reachableMethods.forall(d.methods.contains(_))) { + for (encodedName <- reachableMethods) { + addTree(d.methods(encodedName)) + } + } else { + classDef.defs.foreach { + case m: MethodDef if m.name.isInstanceOf[Ident] => + if (classInfo.methodInfos(m.name.name).isReachable) { + addTree(d.methods.getOrElseUpdate(m.name.name, + emitFun(classInfo.encodedName, m))) + } + case _ => + } + } + } + + if (classInfo.isImplClass) { + if (useInliner) { + for { + method <- optimizer.findTraitImpl(classInfo.encodedName).methods.values + if (classInfo.methodInfos(method.encodedName).isReachable) + } { + addTree(method.desugaredDef) + } + } else { + addReachableMethods(classEmitter.genTraitImplMethod) + } + } else if (!classInfo.hasMoreThanData) { + // there is only the data anyway + addTree(d.wholeClass.getOrElseUpdate( + classEmitter.genClassDef(classDef))) + } else { + if (classInfo.isAnySubclassInstantiated) { + addTree(d.constructor.getOrElseUpdate( + classEmitter.genConstructor(classDef))) + if (useInliner) { + for { + method <- optimizer.findClass(classInfo.encodedName).methods.values + if (classInfo.methodInfos(method.encodedName).isReachable) + } { + addTree(method.desugaredDef) + } + } else { + addReachableMethods(classEmitter.genMethod) + } + addTree(d.exportedMembers.getOrElseUpdate(js.Block { + classDef.defs collect { + case m: MethodDef if m.name.isInstanceOf[StringLiteral] => + classEmitter.genMethod(classInfo.encodedName, m) + case p: PropertyDef => + classEmitter.genProperty(classInfo.encodedName, p) + } + }(classDef.pos))) + } + if (classInfo.isDataAccessed) { + addTree(d.typeData.getOrElseUpdate(js.Block( + classEmitter.genInstanceTests(classDef), + classEmitter.genArrayInstanceTests(classDef), + classEmitter.genTypeData(classDef) + )(classDef.pos))) + } + if (classInfo.isAnySubclassInstantiated) + addTree(d.setTypeData.getOrElseUpdate( + classEmitter.genSetTypeData(classDef))) + if (classInfo.isModuleAccessed) + addTree(d.moduleAccessor.getOrElseUpdate( + classEmitter.genModuleAccessor(classDef))) + addTree(d.classExports.getOrElseUpdate( + classEmitter.genClassExports(classDef))) + } + } + + + for { + classInfo <- analyzer.classInfos.values.toSeq.sortWith(compareClassInfo) + if classInfo.isNeededAtAll + } { + val optPersistentFile = + persistentState.encodedNameToPersistentFile.get(classInfo.encodedName) + + // if we have a persistent file, this is not a dummy class + optPersistentFile.fold { + if (classInfo.isAnySubclassInstantiated) { + // Subclass will emit constructor that references this dummy class. + // Therefore, we need to emit a dummy parent. + builder.addJSTree( + classEmitter.genDummyParent(classInfo.encodedName)) + } + } { pf => addPersistentFile(classInfo, pf) } + } + } +} + +object ScalaJSOptimizer { + /** Inputs of the Scala.js optimizer. */ + final case class Inputs[T]( + /** The CompleteNCClasspath or the IR files to be packaged. */ + input: T, + /** Manual additions to reachability */ + manuallyReachable: Seq[ManualReachability] = Nil, + /** Elements we won't warn even if they don't exist */ + noWarnMissing: Seq[NoWarnMissing] = Nil + ) + + sealed abstract class ManualReachability + final case class ReachObject(name: String) extends ManualReachability + final case class Instantiate(name: String) extends ManualReachability + final case class ReachMethod(className: String, methodName: String, + static: Boolean) extends ManualReachability + + sealed abstract class NoWarnMissing + final case class NoWarnClass(className: String) extends NoWarnMissing + final case class NoWarnMethod(className: String, methodName: String) + extends NoWarnMissing + + /** Configurations relevant to the optimizer */ + trait OptimizerConfig { + /** Ask to produce source map for the output. Is used in the incremental + * optimizer to decide whether a position change should trigger re-inlining + */ + val wantSourceMap: Boolean + /** If true, performs expensive checks of the IR for the used parts. */ + val checkIR: Boolean + /** If true, the optimizer removes trees that have not been used in the + * last run from the cache. Otherwise, all trees that has been used once, + * are kept in memory. */ + val unCache: Boolean + /** If true, no optimizations are performed */ + val disableOptimizer: Boolean + /** If true, nothing is performed incrementally */ + val batchMode: Boolean + } + + /** Configuration for the output of the Scala.js optimizer. */ + final case class OutputConfig( + /** Writer for the output. */ + output: WritableVirtualJSFile, + /** Cache file */ + cache: Option[WritableVirtualTextFile] = None, + /** Ask to produce source map for the output */ + wantSourceMap: Boolean = false, + /** Base path to relativize paths in the source map. */ + relativizeSourceMapBase: Option[URI] = None, + /** If true, performs expensive checks of the IR for the used parts. */ + checkIR: Boolean = false, + /** If true, the optimizer removes trees that have not been used in the + * last run from the cache. Otherwise, all trees that has been used once, + * are kept in memory. */ + unCache: Boolean = true, + /** If true, no optimizations are performed */ + disableOptimizer: Boolean = false, + /** If true, nothing is performed incrementally */ + batchMode: Boolean = false + ) extends OptimizerConfig + + // Private helpers ----------------------------------------------------------- + + private final class PersistentState { + val files = mutable.Map.empty[String, PersistentIRFile] + val encodedNameToPersistentFile = + mutable.Map.empty[String, PersistentIRFile] + + var statsReused: Int = 0 + var statsInvalidated: Int = 0 + var statsTreesRead: Int = 0 + + var wasWithSourceMap: Boolean = true + + def startRun(): Unit = { + statsReused = 0 + statsInvalidated = 0 + statsTreesRead = 0 + for (file <- files.values) + file.startRun() + } + + def getPersistentIRFile(irFile: VirtualScalaJSIRFile): PersistentIRFile = { + val file = files.getOrElseUpdate(irFile.path, + new PersistentIRFile(irFile.path)) + if (file.updateFile(irFile)) + statsReused += 1 + else + statsInvalidated += 1 + encodedNameToPersistentFile += ((file.info.encodedName, file)) + file + } + + def endRun(unCache: Boolean): Unit = { + // "Garbage-collect" persisted versions of files that have disappeared + files.retain((_, f) => f.cleanAfterRun(unCache)) + encodedNameToPersistentFile.clear() + } + } + + private final class PersistentIRFile(val path: String) { + import ir.Trees._ + + private[this] var existedInThisRun: Boolean = false + private[this] var desugaredUsedInThisRun: Boolean = false + + private[this] var irFile: VirtualScalaJSIRFile = null + private[this] var version: Option[String] = None + private[this] var _info: Infos.ClassInfo = null + private[this] var _tree: ClassDef = null + private[this] var _desugared: Desugared = null + + def startRun(): Unit = { + existedInThisRun = false + desugaredUsedInThisRun = false + } + + def updateFile(irFile: VirtualScalaJSIRFile): Boolean = { + existedInThisRun = true + this.irFile = irFile + if (version.isDefined && version == irFile.version) { + // yeepeeh, nothing to do + true + } else { + version = irFile.version + _info = irFile.info + _tree = null + _desugared = null + false + } + } + + def info: Infos.ClassInfo = _info + + def desugared: Desugared = { + desugaredUsedInThisRun = true + if (_desugared == null) + _desugared = new Desugared + _desugared + } + + def tree: ClassDef = { + if (_tree == null) + _tree = irFile.tree + _tree + } + + def treeIfChanged(lastVersion: Option[String]): Option[(ClassDef, Option[String])] = { + if (lastVersion.isDefined && lastVersion == version) None + else Some((tree, version)) + } + + /** Returns true if this file should be kept for the next run at all. */ + def cleanAfterRun(unCache: Boolean): Boolean = { + irFile = null + if (unCache && !desugaredUsedInThisRun) + _desugared = null // free desugared if unused in this run + existedInThisRun + } + } + + private final class Desugared { + // for class kinds that are not decomposed + val wholeClass = new OneTimeCache[js.Tree] + + val constructor = new OneTimeCache[js.Tree] + val methodNames = new OneTimeCache[List[String]] + val methods = mutable.Map.empty[String, js.Tree] + val exportedMembers = new OneTimeCache[js.Tree] + val typeData = new OneTimeCache[js.Tree] + val setTypeData = new OneTimeCache[js.Tree] + val moduleAccessor = new OneTimeCache[js.Tree] + val classExports = new OneTimeCache[js.Tree] + } + + private final class OneTimeCache[A >: Null] { + private[this] var value: A = null + def getOrElseUpdate(v: => A): A = { + if (value == null) + value = v + value + } + } +} |