From 2c94d0f24a37fa079b56d534b0b0a4574209215b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 25 Jul 2015 12:10:02 -0700 Subject: [SPARK-9192][SQL] add initialization phase for nondeterministic expression Currently nondeterministic expression is broken without a explicit initialization phase. Let me take `MonotonicallyIncreasingID` as an example. This expression need a mutable state to remember how many times it has been evaluated, so we use `transient var count: Long` there. By being transient, the `count` will be reset to 0 and **only** to 0 when serialize and deserialize it, as deserialize transient variable will result to default value. There is *no way* to use another initial value for `count`, until we add the explicit initialization phase. Another use case is local execution for `LocalRelation`, there is no serialize and deserialize phase and thus we can't reset mutable states for it. Author: Wenchen Fan Closes #7535 from cloud-fan/init and squashes the following commits: 6c6f332 [Wenchen Fan] add test ef68ff4 [Wenchen Fan] fix comments 9eac85e [Wenchen Fan] move init code to interpreted class bb7d838 [Wenchen Fan] pulls out nondeterministic expressions into a project b4a4fc7 [Wenchen Fan] revert a refactor 86fee36 [Wenchen Fan] add initialization phase for nondeterministic expression --- .../spark/sql/catalyst/analysis/Analyzer.scala | 35 ++++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 19 ++-- .../sql/catalyst/expressions/Expression.scala | 21 ++++- .../sql/catalyst/expressions/Projection.scala | 10 ++ .../sql/catalyst/expressions/predicates.scala | 4 + .../spark/sql/catalyst/expressions/random.scala | 12 ++- .../catalyst/plans/logical/basicOperators.scala | 3 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 96 ++++++++----------- .../spark/sql/catalyst/analysis/AnalysisTest.scala | 105 +++++++++++++++++++++ .../expressions/ExpressionEvalHelper.scala | 4 + .../expressions/MonotonicallyIncreasingID.scala | 13 ++- .../execution/expressions/SparkPartitionID.scala | 8 +- 12 files changed, 254 insertions(+), 76 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e916887187..a723e92114 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ import scala.collection.mutable.ArrayBuffer @@ -78,7 +79,9 @@ class Analyzer( GlobalAggregates :: UnresolvedHavingClauseAttributes :: HiveTypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*) + extendedResolutionRules : _*), + Batch("Nondeterministic", Once, + PullOutNondeterministic) ) /** @@ -910,6 +913,34 @@ class Analyzer( Project(finalProjectList, withWindow) } } + + /** + * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, + * put them into an inner Project and finally project them away at the outer Project. + */ + object PullOutNondeterministic extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Project => p + case f: Filter => f + + // todo: It's hard to write a general rule to pull out nondeterministic expressions + // from LogicalPlan, currently we only do it for UnaryNode which has same output + // schema with its child. + case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne + }.toMap + val newPlan = p.transformExpressions { case e => + nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) + } + val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child) + Project(p.output, newPlan.withNewChildren(newChild :: Nil)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 81d473c113..a373714832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -38,10 +37,10 @@ trait CheckAnalysis { throw new AnalysisException(msg) } - def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { - case e: Generator => true - }).nonEmpty + case e: Generator => e + }).length > 1 } def checkAnalysis(plan: LogicalPlan): Unit = { @@ -137,13 +136,21 @@ trait CheckAnalysis { s""" |Failure when resolving conflicting references in Join: |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") + case o if o.expressions.exists(!_.deterministic) && + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + failAnalysis( + s"""nondeterministic expressions are only allowed in Project or Filter, found: + | ${o.expressions.map(_.prettyString).mkString(",")} + |in operator ${operator.simpleString} + """.stripMargin) + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3f72e6e184..cb4c3f24b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -196,7 +196,26 @@ trait Unevaluable extends Expression { * An expression that is nondeterministic. */ trait Nondeterministic extends Expression { - override def deterministic: Boolean = false + final override def deterministic: Boolean = false + final override def foldable: Boolean = false + + private[this] var initialized = false + + final def initialize(): Unit = { + if (!initialized) { + initInternal() + initialized = true + } + } + + protected def initInternal(): Unit + + final override def eval(input: InternalRow = null): Any = { + require(initialized, "nondeterministic expression should be initialized before evaluate") + evalInternal(input) + } + + protected def evalInternal(input: InternalRow): Any } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index fb873e7e99..c1ed9cf7ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) + // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) + private[this] val exprArray = expressions.toArray private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) def currentValue: InternalRow = mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3f1bd2a925..5bfe1cad24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -30,6 +30,10 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { + expression.foreach { + case n: Nondeterministic => n.initialize() + case _ => + } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index aef24a5486..8f30519697 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** * Record ID within each partition. By being transient, the Random Number Generator is - * reset every time we serialize and deserialize it. + * reset every time we serialize and deserialize and initialize it. */ - @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + @transient protected var rng: XORShiftRandom = _ + + override protected def initInternal(): Unit = { + rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + } override def nullable: Boolean = false @@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ case class Rand(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextDouble() + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() def this() = this(Utils.random.nextLong()) @@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG { /** Generate a random column with i.i.d. gaussian random distribution. */ case class Randn(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextGaussian() + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() def this() = this(Utils.random.nextLong()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 57a12820fa..8e1a236e29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { - val limit = limitExpr.eval(null).asInstanceOf[Int] + val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum Statistics(sizeInBytes = sizeInBytes) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7e67427237..ed645b618d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +// todo: remove this and use AnalysisTest instead. object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -55,7 +52,7 @@ object AnalysisSuite { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) val nestedRelation = LocalRelation( @@ -81,8 +78,7 @@ object AnalysisSuite { } -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisSuite extends AnalysisTest { test("union project *") { val plan = (1 to 100) @@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyzer.execute(plan).resolved) + assertAnalysisSuccess(plan) } test("check project's resolved") { @@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { } test("analyze project") { - assert( - caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === - Project(testRelation.output, testRelation)) - - assert( - caseSensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - val e = intercept[AnalysisException] { - caseSensitiveAnalyze( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) - } - assert(e.getMessage().toLowerCase.contains("cannot resolve")) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) + checkAnalysis( + Project(Seq(UnresolvedAttribute("a")), testRelation), + Project(testRelation.output, testRelation)) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation)) + + assertAnalysisError( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Seq("cannot resolve")) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) } test("resolve relations") { - val e = intercept[RuntimeException] { - caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) - } - assert(e.getMessage == "Table Not Found: tAbLe") + assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe")) - assert( - caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false) } - test("divide should be casted into fractional types") { - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyzer.execute( testRelation2.select( 'a / Literal(2) as 'div1, @@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList + // StringType will be promoted into Double assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double + assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) } + + test("pull out nondeterministic expressions from unary LogicalPlan") { + val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + RepartitionByExpression(Seq(projected.toAttribute), + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala new file mode 100644 index 0000000000..fdb4f28950 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.types._ + +trait AnalysisTest extends PlanTest { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + + val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { + val caseSensitiveConf = new SimpleCatalystConf(true) + val caseInsensitiveConf = new SimpleCatalystConf(false) + + val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) + val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) + + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } -> + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + } + + protected def getAnalyzer(caseSensitive: Boolean) = { + if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer + } + + protected def checkAnalysis( + inputPlan: LogicalPlan, + expectedPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + val actualPlan = analyzer.execute(inputPlan) + analyzer.checkAnalysis(actualPlan) + comparePlans(actualPlan, expectedPlan) + } + + protected def assertAnalysisSuccess( + inputPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + + protected def assertAnalysisError( + inputPlan: LogicalPlan, + expectedErrors: Seq[String], + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + // todo: make sure we throw AnalysisException during analysis + val e = intercept[Exception] { + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + expectedErrors.forall(e.getMessage.contains) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 4930219aa6..852a8b235f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -64,6 +64,10 @@ trait ExpressionEvalHelper { } protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + expression.foreach { + case n: Nondeterministic => n.initialize() + case _ => + } expression.eval(inputRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 2645eb1854..eca36b3274 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -37,17 +37,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with /** * Record ID within each partition. By being transient, count's value is reset to 0 every time - * we serialize and deserialize it. + * we serialize and deserialize and initialize it. */ - @transient private[this] var count: Long = 0L + @transient private[this] var count: Long = _ - @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + @transient private[this] var partitionMask: Long = _ + + override protected def initInternal(): Unit = { + count = 0L + partitionMask = TaskContext.getPartitionId().toLong << 33 + } override def nullable: Boolean = false override def dataType: DataType = LongType - override def eval(input: InternalRow): Long = { + override protected def evalInternal(input: InternalRow): Long = { val currentCount = count count += 1 partitionMask + currentCount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 53ddd47e3e..61ef079d89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -33,9 +33,13 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.getPartitionId() + @transient private[this] var partitionId: Int = _ - override def eval(input: InternalRow): Int = partitionId + override protected def initInternal(): Unit = { + partitionId = TaskContext.getPartitionId() + } + + override protected def evalInternal(input: InternalRow): Int = partitionId override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val idTerm = ctx.freshName("partitionId") -- cgit v1.2.3