 sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala             |  2 +-
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala |  4 +++-
 3 files changed, 57 insertions(+), 14 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index e8e42d72d4..52c2971b73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -334,7 +334,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
- // println(s"${CodeFormatter.format(cleanedSource)}")
+ logDebug(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
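
This first change replaces the commented-out `println` with `logDebug`, so the formatted generated source can be inspected through the normal logging configuration instead of by editing the code. A minimal sketch of how one might surface it, assuming a log4j-backed setup and an example query (the logger name follows the class name):

    import org.apache.log4j.{Level, Logger}

    // Raise the operator's logger to DEBUG so the formatted generated
    // source is logged whenever a whole-stage-codegen plan is compiled.
    Logger.getLogger("org.apache.spark.sql.execution.WholeStageCodegen")
      .setLevel(Level.DEBUG)

    // Any query that goes through WholeStageCodegen now logs its source.
    sqlContext.range(0, 100).filter("id > 5").collect()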
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6ebbc8be6f..6e2a5aa4f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -74,8 +74,27 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
}
-case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode with CodegenSupport {
- override def output: Seq[Attribute] = child.output
+case class Filter(condition: Expression, child: SparkPlan)
+ extends UnaryNode with CodegenSupport with PredicateHelper {
+
+ // Split out all the IsNotNull predicates from the condition.
+ private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+ case IsNotNull(a) if child.output.contains(a) => true
+ case _ => false
+ }
+
+ // The columns that will be filtered out by `IsNotNull` can be considered non-nullable.
+ private val notNullAttributes = notNullPreds.flatMap(_.references)
+
+ override def output: Seq[Attribute] = {
+ child.output.map { a =>
+ if (a.nullable && notNullAttributes.contains(a)) {
+ a.withNullability(false)
+ } else {
+ a
+ }
+ }
+ }
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -90,20 +109,42 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- val expr = ExpressionCanonicalizer.execute(
- BindReferences.bindReference(condition, child.output))
+
+ // filter out the nulls
+ val filterOutNull = notNullAttributes.map { a =>
+ val idx = child.output.indexOf(a)
+ s"if (${input(idx).isNull}) continue;"
+ }.mkString("\n")
+
ctx.currentVars = input
- val eval = expr.gen(ctx)
- val nullCheck = if (expr.nullable) {
- s"!${eval.isNull} &&"
- } else {
- s""
+ val predicates = otherPreds.map { e =>
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(e, output))
+ val ev = bound.gen(ctx)
+ val nullCheck = if (bound.nullable) {
+ s"${ev.isNull} || "
+ } else {
+ s""
+ }
+ s"""
+ |${ev.code}
+ |if (${nullCheck}!${ev.value}) continue;
+ """.stripMargin
+ }.mkString("\n")
+
+ // Reset isNull to false for the not-null columns, so that downstream operators can
+ // generate better code (eliminate dead null-check branches).
+ val resultVars = input.zipWithIndex.map { case (ev, i) =>
+ if (notNullAttributes.contains(child.output(i))) {
+ ev.isNull = "false"
+ }
+ ev
}
s"""
- |${eval.code}
- |if (!($nullCheck ${eval.value})) continue;
+ |$filterOutNull
+ |$predicates
|$numOutput.add(1);
- |${consume(ctx, ctx.currentVars)}
+ |${consume(ctx, resultVars)}
""".stripMargin
}
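
The rewritten `doConsume` first emits a bare `if (isNull) continue;` for each not-null column, then generates each remaining predicate separately so that every one short-circuits the row on its own. Note that the remaining predicates are bound against `output` rather than `child.output`: there the not-null columns are already non-nullable, so `bound.nullable` is frequently false and the extra null guard disappears. A hedged sketch of the string this would produce for `WHERE a IS NOT NULL AND a > 5`, written the same way the operator builds it (all variable names are invented; the real ones come from `CodegenContext.freshName`):

    // Roughly what filterOutNull + predicates + the metric update add up to:
    val generated =
      """
        |if (filter_isNull0) continue;
        |boolean filter_value2 = filter_value0 > 5;
        |if (!filter_value2) continue;
        |filter_numOutputRows.add(1);
      """.stripMargin
    // consume(ctx, resultVars) is then generated with isNull pinned to
    // "false" for column `a`, letting downstream code drop its null branches.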
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index d83486df02..4143e944e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -55,7 +55,9 @@ case class BroadcastNestedLoopJoin(
UnsafeProjection.create(output, output)
} else {
// Always put the stream side on left to simplify implementation
- UnsafeProjection.create(output, streamed.output ++ broadcast.output)
+ // both the left and the right side could be null
+ UnsafeProjection.create(
+ output, (streamed.output ++ broadcast.output).map(_.withNullability(true)))
}
}
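
In an outer `BroadcastNestedLoopJoin`, either side of a joined row can be null-padded, so the projection must not trust the children's declared nullability; presumably this matters more now that `Filter` can narrow nullability as above. `withNullability` copies an attribute with only the flag changed. A small illustration (the attribute is hypothetical):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType, nullable = false)()
    // Same name and exprId, but nullable, so the generated projection
    // keeps its null checks for rows padded by the outer join.
    val aNullable = a.withNullability(true)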