author    Liang-Chi Hsieh <viirya@appier.com>    2015-08-05 11:38:56 -0700
committer Davies Liu <davies.liu@gmail.com>      2015-08-05 11:38:56 -0700
commit    e1e05873fc75781b6dd3f7fadbfb57824f83054e (patch)
tree      58f80041e2174bba31dbccaf4417120002fa2922 /sql
parent    1f8c364b9c6636f06986f5f80d5a49b7a7772ac3 (diff)
[SPARK-9403] [SQL] Add codegen support in In and InSet
This continues tarekauel's work in #7778.

Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7893 from viirya/codegen_in and squashes the following commits:

81ff97b [Liang-Chi Hsieh] For comments.
47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
cf4bf41 [Liang-Chi Hsieh] For comments.
f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
446bbcd [Liang-Chi Hsieh] Fix bug.
b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test.
224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types.
86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold.
b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet
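A minimal usage sketch of the user-visible effect (not part of the patch; `sqlContext`, the sample data, and the column name are assumed — Column.in is the method the patch's own tests exercise):

    import sqlContext.implicits._

    val df = Seq(1, 2, 3, 12).toDF("a")

    // Short literal list: stays as In, which this patch now code-generates
    // instead of falling back to interpreted evaluation.
    df.filter($"a".in(1, 2, 3)).count()

    // More than 10 literals: OptimizeIn rewrites In to InSet, which this
    // patch also code-generates.
    df.filter($"a".in((1 to 11): _*)).count()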
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala       | 63
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala          |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala   | 37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala    | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala  |  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala                     |  6
6 files changed, 119 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b69bbabee7..68c832d719 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -97,32 +100,80 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*/
-case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate
+ with ImplicitCastInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (list.exists(l => l.dataType != value.dataType)) {
+ TypeCheckResult.TypeCheckFailure(
+ "Arguments must be same type")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
override def children: Seq[Expression] = value +: list
- override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+ override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
override def eval(input: InternalRow): Any = {
val evaluatedValue = value.eval(input)
list.exists(e => e.eval(input) == evaluatedValue)
}
-}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val valueGen = value.gen(ctx)
+ val listGen = list.map(_.gen(ctx))
+ val listCode = listGen.map(x =>
+ s"""
+ if (!${ev.primitive}) {
+ ${x.code}
+ if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
+ ${ev.primitive} = true;
+ }
+ }
+ """).mkString("\n")
+ s"""
+ ${valueGen.code}
+ boolean ${ev.primitive} = false;
+ boolean ${ev.isNull} = false;
+ $listCode
+ """
+ }
+}
/**
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
-case class InSet(child: Expression, hset: Set[Any])
- extends UnaryExpression with Predicate with CodegenFallback {
+case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
- override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+ override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
override def eval(input: InternalRow): Any = {
hset.contains(child.eval(input))
}
+
+ def getHSet(): Set[Any] = hset
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val setName = classOf[Set[Any]].getName
+ val InSetName = classOf[InSet].getName
+ val childGen = child.gen(ctx)
+ ctx.references += this
+ val hsetTerm = ctx.freshName("hset")
+ ctx.addMutableState(setName, hsetTerm,
+ s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
+ s"""
+ ${childGen.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
+ """
+ }
}
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
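Both expressions keep their interpreted eval path, which the generated Java above must agree with. A sketch of those semantics (not from the diff; EmptyRow is used the same way the optimizer code in this patch uses it):

    import org.apache.spark.sql.catalyst.expressions.{EmptyRow, In, InSet, Literal}

    // Interpreted evaluation that the codegen templates mirror.
    In(Literal(1), Seq(Literal(1), Literal(2))).eval(EmptyRow)   // true: 1 is in (1, 2)
    InSet(Literal(3), Set[Any](1, 2)).eval(EmptyRow)             // false: 3 is not in {1, 2}

Note the design choice in InSet.genCode: the expression registers itself in ctx.references, and the generated class fetches the hash set at runtime through getHSet(), so the set never has to be rendered into Java source.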
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29d706dcb3..4ab5ac2c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
- case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+ case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 =>
val hSet = list.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
}
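With the new size guard the rewrite is conditional. A sketch of the rule's effect, mirroring the updated OptimizeInSuite below (built with catalyst's test DSL; the trailing comments state the expected outcome, not executed output):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{In, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val relation = LocalRelation('a.int)
    val small = relation.where(In('a, Seq(Literal(1), Literal(2)))).analyze
    val large = relation.where(In('a, (1 to 11).map(Literal(_)))).analyze
    // OptimizeIn leaves `small` unchanged (2 items, under the threshold);
    // in `large` the predicate becomes InSet('a, (1 to 11).toSet).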
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index d7eb13c50b..7beef71845 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.types._
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
+
+ val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+ LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+ primitiveTypes.map { t =>
+ val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val inputData = Seq.fill(10) {
+ val value = dataGen.apply()
+ value match {
+ case d: Double if d.isNaN => 0.0d
+ case f: Float if f.isNaN => 0.0f
+ case _ => value
+ }
+ }
+ val input = inputData.map(Literal(_))
+ checkEvaluation(In(input(0), input.slice(1, 10)),
+ inputData.slice(1, 10).contains(inputData(0)))
+ }
}
test("INSET") {
@@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS), false)
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
+
+ val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+ LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+ primitiveTypes.map { t =>
+ val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val inputData = Seq.fill(10) {
+ val value = dataGen.apply()
+ value match {
+ case d: Double if d.isNaN => 0.0d
+ case f: Float if f.isNaN => 0.0f
+ case _ => value
+ }
+ }
+ val input = inputData.map(Literal(_))
+ checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
+ inputData.slice(1, 10).contains(inputData(0)))
+ }
}
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 1d433275fe..6f7b5b9572 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
- test("OptimizedIn test: In clause optimized to InSet") {
+ test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
.analyze
val optimized = Optimize.execute(originalQuery.analyze)
+ comparePlans(optimized, originalQuery)
+ }
+
+ test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
+ .analyze
+
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
- .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
+ .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a43bccbe69..e5dc676b87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
+ // Because we only convert In to InSet in the Optimizer when there are more than a
+ // certain number of items, it is possible we still get an In expression here that
+ // needs to be pushed down.
+ case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+ val hSet = list.map(e => e.eval(EmptyRow))
+ Some(sources.In(a.name, hSet.toArray))
+
case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
case expressions.IsNotNull(a: Attribute) =>
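sources.In is the public filter type that data sources receive. A sketch of the translation this new case performs (illustrative only; note the values must be an Array[Any]):

    import org.apache.spark.sql.sources

    // A literal-only In that stayed below the InSet threshold is still
    // pushed down as a source-level In filter:
    //   expressions.In(a, Seq(Literal(1), Literal(2)))  ==>
    val pushed = sources.In("a", Array[Any](1, 2))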
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 35ca0b4c7c..b351380373 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".in("z", "y")),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
+
+ val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
+
+ intercept[AnalysisException] {
+ df2.filter($"a".in($"b"))
+ }
}
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(