diff options
author | Davies Liu <davies@databricks.com> | 2015-06-30 10:48:49 -0700 |
---|---|---|
committer | Davies Liu <davies@databricks.com> | 2015-06-30 10:48:49 -0700 |
commit | fbb267ed6fe799a58f88c2fba2d41e954e5f1547 (patch) | |
tree | 19f499236591e938de23399880a9e42a55a0ad46 /sql | |
parent | 5fa0863626aaf5a9a41756a0b1ec82bddccbf067 (diff) | |
download | spark-fbb267ed6fe799a58f88c2fba2d41e954e5f1547.tar.gz spark-fbb267ed6fe799a58f88c2fba2d41e954e5f1547.tar.bz2 spark-fbb267ed6fe799a58f88c2fba2d41e954e5f1547.zip |
[SPARK-8713] Make codegen thread safe
Codegen takes three steps:
1. Take a list of expressions, convert them into Java source code and a list of expressions that don't not support codegen (fallback to interpret mode).
2. Compile the Java source into Java class (bytecode)
3. Using the Java class and the list of expression to build a Projection.
Currently, we cache the whole three steps, the key is a list of expression, result is projection. Because some of expressions (which may not thread-safe, for example, Random) will be hold by the Projection, the projection maybe not thread safe.
This PR change to only cache the second step, then we can build projection using codegen even some expressions are not thread-safe, because the cache will not hold any expression anymore.
cc marmbrus rxin JoshRosen
Author: Davies Liu <davies@databricks.com>
Closes #7101 from davies/codegen_safe and squashes the following commits:
7dd41f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe
847bd08 [Davies Liu] don't use scala.refect
4ddaaed [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe
1793cf1 [Davies Liu] make codegen thread safe
Diffstat (limited to 'sql')
12 files changed, 24 insertions, 54 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aed48921bd..b5063f32fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -61,14 +61,6 @@ abstract class Expression extends TreeNode[Expression] { def eval(input: InternalRow = null): Any /** - * Return true if this expression is thread-safe, which means it could be used by multiple - * threads in the same time. - * - * An expression that is not thread-safe can not be cached and re-used, especially for codegen. - */ - def isThreadSafe: Boolean = true - - /** * Returns an [[GeneratedExpressionCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. * @@ -76,9 +68,6 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - if (!isThreadSafe) { - throw new Exception(s"$this is not thread-safe, can not be used in codegen") - } val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) @@ -178,8 +167,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" - override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe - /** * Short hand for generating binary evaluation code. * If either of the sub-expressions is null, the result of this computation @@ -237,7 +224,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def isThreadSafe: Boolean = child.isThreadSafe /** * Called by unary expressions to generate a code block that returns null if its parent returns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index dbb4381d54..ebabb6f117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -956,7 +956,4 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:on private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) - - // TODO(davies): make ScalaUDF work with codegen - override def isThreadSafe: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf6a6a1240..a64027e48a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -235,11 +235,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * Compile the Java source code into a Java class, using Janino. - * - * It will track the time used to compile */ protected def compile(code: String): GeneratedClass = { - val startTime = System.nanoTime() + cache.get(code) + } + + /** + * Compile the Java source code into a Java class, using Janino. + */ + private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) @@ -251,9 +255,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin logError(s"failed to compile:\n $code", e) throw e } - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } @@ -266,16 +267,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * automatically, in order to constrain its memory footprint. Note that this cache does not use * weak keys/values and thus does not respond to memory pressure. */ - protected val cache = CacheBuilder.newBuilder() + private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = { + new CacheLoader[String, GeneratedClass]() { + override def load(code: String): GeneratedClass = { val startTime = System.nanoTime() - val result = create(in) + val result = doCompile(code) val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logInfo(s"Code generated expression $in in $timeMs ms") + logInfo(s"Code generated in $timeMs ms") result } }) @@ -285,7 +286,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin generate(bind(expressions, inputSchema)) /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def generate(expressions: InType): OutType = create(canonicalize(expressions)) /** * Create a new codegen context for expression evaluator, used to store those diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 7ed2c5adde..97cb16045a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -38,7 +38,6 @@ class BaseOrdering extends Ordering[InternalRow] { */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) @@ -47,8 +46,6 @@ object GenerateOrdering in.map(BindReferences.bindReference(_, inputSchema)) protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { - val a = newTermName("a") - val b = newTermName("b") val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => @@ -56,9 +53,9 @@ object GenerateOrdering val evalB = order.child.gen(ctx) val asc = order.direction == Ascending s""" - i = $a; + i = a; ${evalA.code} - i = $b; + i = b; ${evalB.code} if (${evalA.isNull} && ${evalB.isNull}) { // Nothing @@ -80,7 +77,7 @@ object GenerateOrdering return new SpecificOrdering(expr); } - class SpecificOrdering extends ${typeOf[BaseOrdering]} { + class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 39d32b78cc..5be47175fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -32,7 +32,6 @@ abstract class BaseProject extends Projection {} * primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -157,7 +156,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${typeOf[BaseProject]} { + class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; public SpecificProjection($exprType[] expr) { @@ -170,7 +169,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - final class SpecificRow extends ${typeOf[MutableRow]} { + final class SpecificRow extends ${classOf[MutableRow].getName} { $columns @@ -224,7 +223,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public InternalRow copy() { Object[] arr = new Object[${expressions.length}]; ${copyColumns} - return new ${typeOf[GenericInternalRow]}(arr); + return new ${classOf[GenericInternalRow].getName}(arr); } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 6f56a9ec7b..81ebda3060 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -117,8 +117,6 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) - override def isThreadSafe: Boolean = child.isThreadSafe - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5d5911403e..78be282434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -51,8 +51,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def isThreadSafe: Boolean = children.forall(_.isThreadSafe) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = true; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 47f56b2b7e..7739a9f949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { + if (codegenEnabled) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) @@ -168,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled && expressions.forall(_.isThreadSafe)) { + if(codegenEnabled) { GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) @@ -178,7 +178,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled && expression.isThreadSafe) { + if (codegenEnabled) { GeneratePredicate.generate(expression, inputSchema) } else { InterpretedPredicate.create(expression, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 3b217348b7..68914cf85c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -48,6 +48,4 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { count += 1 (TaskContext.get().partitionId().toLong << 33) + currentCount } - - override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 54c8eeb41a..42b51caab5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -270,7 +270,7 @@ private[sql] case class InsertIntoHadoopFsRelation( inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { + if (codegenEnabled) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7005c7079a..0b875304f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -591,7 +591,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio rdd.map(_.asInstanceOf[InternalRow]) } converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) { + val buildProjection = if (codegenEnabled) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d7827d56ca..4dea561ae5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -120,8 +120,6 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - override def isThreadSafe: Boolean = false - // TODO: Finish input output types. override def eval(input: InternalRow): Any = { unwrap( @@ -180,8 +178,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr lazy val dataType: DataType = inspectorToDataType(returnInspector) - override def isThreadSafe: Boolean = false - override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. |