author    Wenchen Fan <cloud0fan@outlook.com>  2015-07-25 12:10:02 -0700
committer Reynold Xin <rxin@databricks.com>    2015-07-25 12:10:02 -0700
commit    2c94d0f24a37fa079b56d534b0b0a4574209215b (patch)
tree      0aa5ca6ff2c9bb9728bfb6445bd4bfeb4cacf7b7
parent    e2ec018e37cb699077b5fa2bd662f2055cb42296 (diff)
[SPARK-9192][SQL] add initialization phase for nondeterministic expression
Currently, nondeterministic expressions are broken without an explicit initialization phase.

Take `MonotonicallyIncreasingID` as an example. This expression needs mutable state to remember how many times it has been evaluated, so we use `transient var count: Long` there. By being transient, `count` is reset to 0 and **only** to 0 when we serialize and deserialize the expression, since deserializing a transient variable restores it to its default value. There is *no way* to use a different initial value for `count` until we add an explicit initialization phase.

Another use case is local execution for `LocalRelation`: since there is no serialize/deserialize phase there, we cannot reset mutable states for it at all.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7535 from cloud-fan/init and squashes the following commits:

6c6f332 [Wenchen Fan] add test
ef68ff4 [Wenchen Fan] fix comments
9eac85e [Wenchen Fan] move init code to interpreted class
bb7d838 [Wenchen Fan] pulls out nondeterministic expressions into a project
b4a4fc7 [Wenchen Fan] revert a refactor
86fee36 [Wenchen Fan] add initialization phase for nondeterministic expression
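To make the failure mode concrete, here is a minimal, self-contained sketch (the `Counter` class and `copyBySerialization` helper are hypothetical, for illustration only): Java serialization skips `@transient` fields and restores them to their type's default, so mutable state can only ever "reset" to 0, never to a chosen initial value.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object TransientResetDemo {
  // Hypothetical stand-in for an expression with mutable state,
  // such as MonotonicallyIncreasingID's `count`.
  class Counter extends Serializable {
    // Transient fields are skipped during serialization; on deserialization
    // the JVM restores them to the default value (0L), and the initializer
    // below is NOT re-run.
    @transient var count: Long = 100L
  }

  // Round-trips an object through Java serialization, roughly what happens
  // when Spark ships an expression tree to an executor.
  def copyBySerialization[T <: Serializable](obj: T): T = {
    val buffer = new ByteArrayOutputStream()
    new ObjectOutputStream(buffer).writeObject(obj)
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val c = copyBySerialization(new Counter)
    println(c.count) // prints 0, not 100: no hook to pick another initial value
  }
}
```

An explicit `initialize()` call on the deserialized object, which this patch introduces, is exactly that missing hook.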
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 35
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 96
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala | 105
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala | 8
12 files changed, 254 insertions(+), 76 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e916887187..a723e92114 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._
import scala.collection.mutable.ArrayBuffer
@@ -78,7 +79,9 @@ class Analyzer(
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
HiveTypeCoercion.typeCoercionRules ++
- extendedResolutionRules : _*)
+ extendedResolutionRules : _*),
+ Batch("Nondeterministic", Once,
+ PullOutNondeterministic)
)
/**
@@ -910,6 +913,34 @@ class Analyzer(
Project(finalProjectList, withWindow)
}
}
+
+ /**
+ * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter,
+ * put them into an inner Project and finally project them away at the outer Project.
+ */
+ object PullOutNondeterministic extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p: Project => p
+ case f: Filter => f
+
+ // todo: It's hard to write a general rule to pull out nondeterministic expressions
+ // from LogicalPlan, currently we only do it for UnaryNode which has same output
+ // schema with its child.
+ case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
+ val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
+ val ne = e match {
+ case n: NamedExpression => n
+ case _ => Alias(e, "_nondeterministic")()
+ }
+ new TreeNodeRef(e) -> ne
+ }.toMap
+ val newPlan = p.transformExpressions { case e =>
+ nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+ }
+ val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
+ Project(p.output, newPlan.withNewChildren(newChild :: Nil))
+ }
+ }
}
/**
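Schematically, the rewrite this rule performs looks as follows (a fragment mirroring the new test added to AnalysisSuite below; `testRelation` comes from that suite, and it is not a standalone program):

```scala
// Before: Rand sits directly inside an operator that is neither
// Project nor Filter.
RepartitionByExpression(Seq(Rand(33)), testRelation)

// After the rule: Rand is evaluated exactly once in an inner Project,
// the repartition references only the resulting attribute, and an outer
// Project restores the original output schema.
val rand = Alias(Rand(33), "_nondeterministic")()
Project(testRelation.output,
  RepartitionByExpression(Seq(rand.toAttribute),
    Project(testRelation.output :+ rand, testRelation)))
```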
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 81d473c113..a373714832 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -38,10 +37,10 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}
- def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
+ protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
- case e: Generator => true
- }).nonEmpty
+ case e: Generator => e
+ }).length > 1
}
def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -137,13 +136,21 @@ trait CheckAnalysis {
s"""
|Failure when resolving conflicting references in Join:
|$plan
- |Conflicting attributes: ${conflictingAttributes.mkString(",")}
- |""".stripMargin)
+ |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+ |""".stripMargin)
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
+ case o if o.expressions.exists(!_.deterministic) &&
+ !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
+ failAnalysis(
+ s"""nondeterministic expressions are only allowed in Project or Filter, found:
+ | ${o.expressions.map(_.prettyString).mkString(",")}
+ |in operator ${operator.simpleString}
+ """.stripMargin)
+
case _ => // Analysis successful!
}
}
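A sketch of what this new check accepts and rejects, written against the Catalyst test DSL and the relations from the suites below (schematic fragments under those assumptions, not a standalone program):

```scala
// Allowed: nondeterministic expressions under Filter or Project.
testRelation.where(Rand(10) > Literal(0.5))
testRelation.select(Alias(Rand(10), "r")())

// Rejected: a Join condition is neither, and binary nodes are not
// rewritten by PullOutNondeterministic, so checkAnalysis fails with
// "nondeterministic expressions are only allowed in Project or Filter".
testRelation.join(testRelation2, condition = Some(Rand(10) > Literal(0.5)))
```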
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 3f72e6e184..cb4c3f24b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -196,7 +196,26 @@ trait Unevaluable extends Expression {
* An expression that is nondeterministic.
*/
trait Nondeterministic extends Expression {
- override def deterministic: Boolean = false
+ final override def deterministic: Boolean = false
+ final override def foldable: Boolean = false
+
+ private[this] var initialized = false
+
+ final def initialize(): Unit = {
+ if (!initialized) {
+ initInternal()
+ initialized = true
+ }
+ }
+
+ protected def initInternal(): Unit
+
+ final override def eval(input: InternalRow = null): Any = {
+ require(initialized, "nondeterministic expression should be initialized before evaluate")
+ evalInternal(input)
+ }
+
+ protected def evalInternal(input: InternalRow): Any
}
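For reference, a minimal sketch of a concrete expression written against this new contract (the `EvaluationCount` class is hypothetical, and codegen support is elided): subclasses supply `initInternal()`/`evalInternal()`, while `initialize()`/`eval()` stay final and enforce the initialize-before-eval ordering.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Nondeterministic}
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical: returns how many times it has been evaluated in the
// current task, resetting cleanly because initInternal() is invoked
// explicitly after deserialization.
case class EvaluationCount() extends LeafExpression with Nondeterministic {
  override def nullable: Boolean = false
  override def dataType: DataType = LongType

  @transient private[this] var count: Long = _

  // Called by the interpreted projection/predicate before any eval;
  // this is where a non-default initial value could now be chosen.
  override protected def initInternal(): Unit = {
    count = 0L
  }

  // The trait's final eval() checks `initialized`, then delegates here.
  override protected def evalInternal(input: InternalRow): Any = {
    count += 1
    count
  }
}
```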
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index fb873e7e99..c1ed9cf7ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ expressions.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize()
+ case _ =>
+ })
+
// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ expressions.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize()
+ case _ =>
+ })
+
private[this] val exprArray = expressions.toArray
private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
def currentValue: InternalRow = mutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3f1bd2a925..5bfe1cad24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -30,6 +30,10 @@ object InterpretedPredicate {
create(BindReferences.bindReference(expression, inputSchema))
def create(expression: Expression): (InternalRow => Boolean) = {
+ expression.foreach {
+ case n: Nondeterministic => n.initialize()
+ case _ =>
+ }
(r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index aef24a5486..8f30519697 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
- * reset every time we serialize and deserialize it.
+ * reset every time we serialize and deserialize and initialize it.
*/
- @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+ @transient protected var rng: XORShiftRandom = _
+
+ override protected def initInternal(): Unit = {
+ rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+ }
override def nullable: Boolean = false
@@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic {
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
case class Rand(seed: Long) extends RDG {
- override def eval(input: InternalRow): Double = rng.nextDouble()
+ override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
def this() = this(Utils.random.nextLong())
@@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG {
/** Generate a random column with i.i.d. gaussian random distribution. */
case class Randn(seed: Long) extends RDG {
- override def eval(input: InternalRow): Double = rng.nextGaussian()
+ override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
def this() = this(Utils.random.nextLong())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 57a12820fa..8e1a236e29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override lazy val statistics: Statistics = {
- val limit = limitExpr.eval(null).asInstanceOf[Int]
+ val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
Statistics(sizeInBytes = sizeInBytes)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7e67427237..ed645b618d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
+// todo: remove this and use AnalysisTest instead.
object AnalysisSuite {
val caseSensitiveConf = new SimpleCatalystConf(true)
val caseInsensitiveConf = new SimpleCatalystConf(false)
@@ -55,7 +52,7 @@ object AnalysisSuite {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
+ AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())
val nestedRelation = LocalRelation(
@@ -81,8 +78,7 @@ object AnalysisSuite {
}
-class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
- import AnalysisSuite._
+class AnalysisSuite extends AnalysisTest {
test("union project *") {
val plan = (1 to 100)
@@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}
- assert(caseInsensitiveAnalyzer.execute(plan).resolved)
+ assertAnalysisSuccess(plan)
}
test("check project's resolved") {
@@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
}
test("analyze project") {
- assert(
- caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
- Project(testRelation.output, testRelation))
-
- assert(
- caseSensitiveAnalyzer.execute(
- Project(Seq(UnresolvedAttribute("TbL.a")),
- UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
- Project(testRelation.output, testRelation))
-
- val e = intercept[AnalysisException] {
- caseSensitiveAnalyze(
- Project(Seq(UnresolvedAttribute("tBl.a")),
- UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
- }
- assert(e.getMessage().toLowerCase.contains("cannot resolve"))
-
- assert(
- caseInsensitiveAnalyzer.execute(
- Project(Seq(UnresolvedAttribute("TbL.a")),
- UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
- Project(testRelation.output, testRelation))
-
- assert(
- caseInsensitiveAnalyzer.execute(
- Project(Seq(UnresolvedAttribute("tBl.a")),
- UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
- Project(testRelation.output, testRelation))
+ checkAnalysis(
+ Project(Seq(UnresolvedAttribute("a")), testRelation),
+ Project(testRelation.output, testRelation))
+
+ checkAnalysis(
+ Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+ Project(testRelation.output, testRelation))
+
+ assertAnalysisError(
+ Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+ Seq("cannot resolve"))
+
+ checkAnalysis(
+ Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+ Project(testRelation.output, testRelation),
+ caseSensitive = false)
+
+ checkAnalysis(
+ Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+ Project(testRelation.output, testRelation),
+ caseSensitive = false)
}
test("resolve relations") {
- val e = intercept[RuntimeException] {
- caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
- }
- assert(e.getMessage == "Table Not Found: tAbLe")
+ assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe"))
- assert(
- caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+ checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation)
- assert(
- caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
+ checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false)
- assert(
- caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+ checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false)
}
-
test("divide should be casted into fractional types") {
- val testRelation2 = LocalRelation(
- AttributeReference("a", StringType)(),
- AttributeReference("b", StringType)(),
- AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType(10, 2))(),
- AttributeReference("e", ShortType)())
-
val plan = caseInsensitiveAnalyzer.execute(
testRelation2.select(
'a / Literal(2) as 'div1,
@@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList
+ // StringType will be promoted into Double
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double
+ assert(pl(3).dataType == DoubleType)
assert(pl(4).dataType == DoubleType)
}
+
+ test("pull out nondeterministic expressions from unary LogicalPlan") {
+ val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
+ val projected = Alias(Rand(33), "_nondeterministic")()
+ val expected =
+ Project(testRelation.output,
+ RepartitionByExpression(Seq(projected.toAttribute),
+ Project(testRelation.output :+ projected, testRelation)))
+ checkAnalysis(plan, expected)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
new file mode 100644
index 0000000000..fdb4f28950
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.types._
+
+trait AnalysisTest extends PlanTest {
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+
+ val testRelation2 = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", StringType)(),
+ AttributeReference("c", DoubleType)(),
+ AttributeReference("d", DecimalType(10, 2))(),
+ AttributeReference("e", ShortType)())
+
+ val nestedRelation = LocalRelation(
+ AttributeReference("top", StructType(
+ StructField("duplicateField", StringType) ::
+ StructField("duplicateField", StringType) ::
+ StructField("differentCase", StringType) ::
+ StructField("differentcase", StringType) :: Nil
+ ))())
+
+ val nestedRelation2 = LocalRelation(
+ AttributeReference("top", StructType(
+ StructField("aField", StringType) ::
+ StructField("bField", StringType) ::
+ StructField("cField", StringType) :: Nil
+ ))())
+
+ val listRelation = LocalRelation(
+ AttributeReference("list", ArrayType(IntegerType))())
+
+ val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
+ val caseSensitiveConf = new SimpleCatalystConf(true)
+ val caseInsensitiveConf = new SimpleCatalystConf(false)
+
+ val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
+ val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
+
+ caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+ caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+
+ new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
+ override val extendedResolutionRules = EliminateSubQueries :: Nil
+ } ->
+ new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) {
+ override val extendedResolutionRules = EliminateSubQueries :: Nil
+ }
+ }
+
+ protected def getAnalyzer(caseSensitive: Boolean) = {
+ if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
+ }
+
+ protected def checkAnalysis(
+ inputPlan: LogicalPlan,
+ expectedPlan: LogicalPlan,
+ caseSensitive: Boolean = true): Unit = {
+ val analyzer = getAnalyzer(caseSensitive)
+ val actualPlan = analyzer.execute(inputPlan)
+ analyzer.checkAnalysis(actualPlan)
+ comparePlans(actualPlan, expectedPlan)
+ }
+
+ protected def assertAnalysisSuccess(
+ inputPlan: LogicalPlan,
+ caseSensitive: Boolean = true): Unit = {
+ val analyzer = getAnalyzer(caseSensitive)
+ analyzer.checkAnalysis(analyzer.execute(inputPlan))
+ }
+
+ protected def assertAnalysisError(
+ inputPlan: LogicalPlan,
+ expectedErrors: Seq[String],
+ caseSensitive: Boolean = true): Unit = {
+ val analyzer = getAnalyzer(caseSensitive)
+ // todo: make sure we throw AnalysisException during analysis
+ val e = intercept[Exception] {
+ analyzer.checkAnalysis(analyzer.execute(inputPlan))
+ }
+ expectedErrors.forall(e.getMessage.contains)
+ }
+}
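A suite built on this trait then reads like the refactored AnalysisSuite above; a hypothetical example (suite and table names are made up for illustration):

```scala
class MyAnalysisSuite extends AnalysisTest {
  test("resolves a simple projection") {
    checkAnalysis(
      Project(Seq(UnresolvedAttribute("a")), testRelation),
      Project(testRelation.output, testRelation))
  }

  test("reports unknown tables") {
    assertAnalysisError(
      UnresolvedRelation(Seq("no_such_table"), None),
      Seq("Table Not Found"))
  }
}
```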
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 4930219aa6..852a8b235f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -64,6 +64,10 @@ trait ExpressionEvalHelper {
}
protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
+ expression.foreach {
+ case n: Nondeterministic => n.initialize()
+ case _ =>
+ }
expression.eval(inputRow)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index 2645eb1854..eca36b3274 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -37,17 +37,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with
/**
* Record ID within each partition. By being transient, count's value is reset to 0 every time
- * we serialize and deserialize it.
+ * we serialize and deserialize and initialize it.
*/
- @transient private[this] var count: Long = 0L
+ @transient private[this] var count: Long = _
- @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33
+ @transient private[this] var partitionMask: Long = _
+
+ override protected def initInternal(): Unit = {
+ count = 0L
+ partitionMask = TaskContext.getPartitionId().toLong << 33
+ }
override def nullable: Boolean = false
override def dataType: DataType = LongType
- override def eval(input: InternalRow): Long = {
+ override protected def evalInternal(input: InternalRow): Long = {
val currentCount = count
count += 1
partitionMask + currentCount
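As a concrete illustration of the ID layout (upper 31 bits come from the partition ID, lower 33 bits from the per-partition record counter):

```scala
val partitionId = 2
val partitionMask = partitionId.toLong << 33  // 17179869184
val count = 5L                                // sixth record in this partition
val id = partitionMask + count                // 17179869189
```

With the explicit initialization phase, both `partitionMask` and `count` are set in `initInternal()` on the executor, so local execution without a serialization round trip also starts from a clean state.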
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 53ddd47e3e..61ef079d89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -33,9 +33,13 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi
override def dataType: DataType = IntegerType
- @transient private lazy val partitionId = TaskContext.getPartitionId()
+ @transient private[this] var partitionId: Int = _
- override def eval(input: InternalRow): Int = partitionId
+ override protected def initInternal(): Unit = {
+ partitionId = TaskContext.getPartitionId()
+ }
+
+ override protected def evalInternal(input: InternalRow): Int = partitionId
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val idTerm = ctx.freshName("partitionId")