aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-06-30 10:48:49 -0700
committerDavies Liu <davies@databricks.com>2015-06-30 10:48:49 -0700
commitfbb267ed6fe799a58f88c2fba2d41e954e5f1547 (patch)
tree19f499236591e938de23399880a9e42a55a0ad46 /sql
parent5fa0863626aaf5a9a41756a0b1ec82bddccbf067 (diff)
downloadspark-fbb267ed6fe799a58f88c2fba2d41e954e5f1547.tar.gz
spark-fbb267ed6fe799a58f88c2fba2d41e954e5f1547.tar.bz2
spark-fbb267ed6fe799a58f88c2fba2d41e954e5f1547.zip
[SPARK-8713] Make codegen thread safe
Codegen takes three steps: 1. Take a list of expressions, convert them into Java source code and a list of expressions that don't not support codegen (fallback to interpret mode). 2. Compile the Java source into Java class (bytecode) 3. Using the Java class and the list of expression to build a Projection. Currently, we cache the whole three steps, the key is a list of expression, result is projection. Because some of expressions (which may not thread-safe, for example, Random) will be hold by the Projection, the projection maybe not thread safe. This PR change to only cache the second step, then we can build projection using codegen even some expressions are not thread-safe, because the cache will not hold any expression anymore. cc marmbrus rxin JoshRosen Author: Davies Liu <davies@databricks.com> Closes #7101 from davies/codegen_safe and squashes the following commits: 7dd41f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe 847bd08 [Davies Liu] don't use scala.refect 4ddaaed [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe 1793cf1 [Davies Liu] make codegen thread safe
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala2
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala4
12 files changed, 24 insertions, 54 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index aed48921bd..b5063f32fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -61,14 +61,6 @@ abstract class Expression extends TreeNode[Expression] {
def eval(input: InternalRow = null): Any
/**
- * Return true if this expression is thread-safe, which means it could be used by multiple
- * threads in the same time.
- *
- * An expression that is not thread-safe can not be cached and re-used, especially for codegen.
- */
- def isThreadSafe: Boolean = true
-
- /**
* Returns an [[GeneratedExpressionCode]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
*
@@ -76,9 +68,6 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
- if (!isThreadSafe) {
- throw new Exception(s"$this is not thread-safe, can not be used in codegen")
- }
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
@@ -178,8 +167,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def toString: String = s"($left $symbol $right)"
- override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe
-
/**
* Short hand for generating binary evaluation code.
* If either of the sub-expressions is null, the result of this computation
@@ -237,7 +224,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
- override def isThreadSafe: Boolean = child.isThreadSafe
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index dbb4381d54..ebabb6f117 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -956,7 +956,4 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi
// scalastyle:on
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
override def eval(input: InternalRow): Any = converter(f(input))
-
- // TODO(davies): make ScalaUDF work with codegen
- override def isThreadSafe: Boolean = false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index bf6a6a1240..a64027e48a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -235,11 +235,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/**
* Compile the Java source code into a Java class, using Janino.
- *
- * It will track the time used to compile
*/
protected def compile(code: String): GeneratedClass = {
- val startTime = System.nanoTime()
+ cache.get(code)
+ }
+
+ /**
+ * Compile the Java source code into a Java class, using Janino.
+ */
+ private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(getClass.getClassLoader)
evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow"))
@@ -251,9 +255,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
logError(s"failed to compile:\n $code", e)
throw e
}
- val endTime = System.nanoTime()
- def timeMs: Double = (endTime - startTime).toDouble / 1000000
- logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms")
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
}
@@ -266,16 +267,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* automatically, in order to constrain its memory footprint. Note that this cache does not use
* weak keys/values and thus does not respond to memory pressure.
*/
- protected val cache = CacheBuilder.newBuilder()
+ private val cache = CacheBuilder.newBuilder()
.maximumSize(100)
.build(
- new CacheLoader[InType, OutType]() {
- override def load(in: InType): OutType = {
+ new CacheLoader[String, GeneratedClass]() {
+ override def load(code: String): GeneratedClass = {
val startTime = System.nanoTime()
- val result = create(in)
+ val result = doCompile(code)
val endTime = System.nanoTime()
def timeMs: Double = (endTime - startTime).toDouble / 1000000
- logInfo(s"Code generated expression $in in $timeMs ms")
+ logInfo(s"Code generated in $timeMs ms")
result
}
})
@@ -285,7 +286,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
generate(bind(expressions, inputSchema))
/** Generates the requested evaluator given already bound expression(s). */
- def generate(expressions: InType): OutType = cache.get(canonicalize(expressions))
+ def generate(expressions: InType): OutType = create(canonicalize(expressions))
/**
* Create a new codegen context for expression evaluator, used to store those
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 7ed2c5adde..97cb16045a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -38,7 +38,6 @@ class BaseOrdering extends Ordering[InternalRow] {
*/
object GenerateOrdering
extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging {
- import scala.reflect.runtime.universe._
protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
@@ -47,8 +46,6 @@ object GenerateOrdering
in.map(BindReferences.bindReference(_, inputSchema))
protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = {
- val a = newTermName("a")
- val b = newTermName("b")
val ctx = newCodeGenContext()
val comparisons = ordering.zipWithIndex.map { case (order, i) =>
@@ -56,9 +53,9 @@ object GenerateOrdering
val evalB = order.child.gen(ctx)
val asc = order.direction == Ascending
s"""
- i = $a;
+ i = a;
${evalA.code}
- i = $b;
+ i = b;
${evalB.code}
if (${evalA.isNull} && ${evalB.isNull}) {
// Nothing
@@ -80,7 +77,7 @@ object GenerateOrdering
return new SpecificOrdering(expr);
}
- class SpecificOrdering extends ${typeOf[BaseOrdering]} {
+ class SpecificOrdering extends ${classOf[BaseOrdering].getName} {
private $exprType[] expressions = null;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 39d32b78cc..5be47175fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -32,7 +32,6 @@ abstract class BaseProject extends Projection {}
* primitive values.
*/
object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
- import scala.reflect.runtime.universe._
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)
@@ -157,7 +156,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
return new SpecificProjection(expr);
}
- class SpecificProjection extends ${typeOf[BaseProject]} {
+ class SpecificProjection extends ${classOf[BaseProject].getName} {
private $exprType[] expressions = null;
public SpecificProjection($exprType[] expr) {
@@ -170,7 +169,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
}
- final class SpecificRow extends ${typeOf[MutableRow]} {
+ final class SpecificRow extends ${classOf[MutableRow].getName} {
$columns
@@ -224,7 +223,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
public InternalRow copy() {
Object[] arr = new Object[${expressions.length}];
${copyColumns}
- return new ${typeOf[GenericInternalRow]}(arr);
+ return new ${classOf[GenericInternalRow].getName}(arr);
}
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 6f56a9ec7b..81ebda3060 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -117,8 +117,6 @@ case class Alias(child: Expression, name: String)(
override def eval(input: InternalRow): Any = child.eval(input)
- override def isThreadSafe: Boolean = child.isThreadSafe
-
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override def dataType: DataType = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 5d5911403e..78be282434 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -51,8 +51,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
result
}
- override def isThreadSafe: Boolean = children.forall(_.isThreadSafe)
-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
s"""
boolean ${ev.isNull} = true;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 47f56b2b7e..7739a9f949 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
+ if (codegenEnabled) {
GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
@@ -168,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
inputSchema: Seq[Attribute]): () => MutableProjection = {
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
+ if(codegenEnabled) {
GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
@@ -178,7 +178,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
- if (codegenEnabled && expression.isThreadSafe) {
+ if (codegenEnabled) {
GeneratePredicate.generate(expression, inputSchema)
} else {
InterpretedPredicate.create(expression, inputSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index 3b217348b7..68914cf85c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -48,6 +48,4 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
count += 1
(TaskContext.get().partitionId().toLong << 33) + currentCount
}
-
- override def isThreadSafe: Boolean = false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 54c8eeb41a..42b51caab5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -270,7 +270,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
inputSchema: Seq[Attribute]): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
+ if (codegenEnabled) {
GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 7005c7079a..0b875304f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -591,7 +591,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
rdd.map(_.asInstanceOf[InternalRow])
}
converted.mapPartitions { rows =>
- val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) {
+ val buildProjection = if (codegenEnabled) {
GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes)
} else {
() => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index d7827d56ca..4dea561ae5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -120,8 +120,6 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre
@transient
protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
- override def isThreadSafe: Boolean = false
-
// TODO: Finish input output types.
override def eval(input: InternalRow): Any = {
unwrap(
@@ -180,8 +178,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr
lazy val dataType: DataType = inspectorToDataType(returnInspector)
- override def isThreadSafe: Boolean = false
-
override def eval(input: InternalRow): Any = {
returnInspector // Make sure initialized.