/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
 * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
 * Used for testing when all relations are already filled in and the analyzer needs only
 * to resolve attribute references.
 */
object SimpleAnalyzer extends Analyzer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) {
      override def createDatabase(
          dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
    },
    new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))

/**
 * Provides a way to keep state during the analysis, which enables us to decouple the concerns
 * of analysis environment from the catalog.
 *
 * Note this is thread local.
 *
 * @param defaultDatabase The default database used in the view resolution, this overrules the
 *                        current catalog database.
 * @param nestedViewDepth The nested depth in the view resolution, this enables us to limit the
 *                        depth of nested views.
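 *
 * A sketch of the intended usage inside this file (`db1` is a hypothetical database name):
 * resolving a view created under `db1` runs the view body with that database as the default,
 * and the original context is restored afterwards.
 * {{{
 *   AnalysisContext.withAnalysisContext(Some("db1")) {
 *     // defaultDatabase is Some("db1") and nestedViewDepth is incremented here
 *     execute(viewChildPlan)
 *   }
 * }}}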
 */
case class AnalysisContext(
    defaultDatabase: Option[String] = None,
    nestedViewDepth: Int = 0)

object AnalysisContext {
  private val value = new ThreadLocal[AnalysisContext]() {
    override def initialValue: AnalysisContext = AnalysisContext()
  }

  def get: AnalysisContext = value.get()
  private def set(context: AnalysisContext): Unit = value.set(context)

  def withAnalysisContext[A](database: Option[String])(f: => A): A = {
    val originContext = value.get()
    val context = AnalysisContext(defaultDatabase = database,
      nestedViewDepth = originContext.nestedViewDepth + 1)
    set(context)
    try f finally { set(originContext) }
  }
}

/**
 * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
 * [[UnresolvedRelation]]s into fully typed objects using information in a
 * [[SessionCatalog]] and a [[FunctionRegistry]].
 */
class Analyzer(
    catalog: SessionCatalog,
    conf: SQLConf,
    maxIterations: Int)
  extends RuleExecutor[LogicalPlan] with CheckAnalysis {

  def this(catalog: SessionCatalog, conf: SQLConf) = {
    this(catalog, conf, conf.optimizerMaxIterations)
  }

  def resolver: Resolver = conf.resolver

  protected val fixedPoint = FixedPoint(maxIterations)

  /**
   * Override to provide additional rules for the "Resolution" batch.
   */
  val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil

  /**
   * Override to provide rules to do post-hoc resolution. Note that these rules will be executed
   * in an individual batch. This batch is to run right after the normal resolution batch and
   * execute its rules in one pass.
   */
  val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil

  lazy val batches: Seq[Batch] = Seq(
    Batch("Hints", fixedPoint,
      new ResolveHints.ResolveBroadcastHints(conf),
      ResolveHints.RemoveAllHints),
    Batch("Simple Sanity Check", Once,
      LookupFunctions),
    Batch("Substitution", fixedPoint,
      CTESubstitution,
      WindowsSubstitution,
      EliminateUnions,
      new SubstituteUnresolvedOrdinals(conf)),
    Batch("Resolution", fixedPoint,
      ResolveTableValuedFunctions ::
      ResolveRelations ::
      ResolveReferences ::
      ResolveCreateNamedStruct ::
      ResolveDeserializer ::
      ResolveNewInstance ::
      ResolveUpCast ::
      ResolveGroupingAnalytics ::
      ResolvePivot ::
      ResolveOrdinalInOrderByAndGroupBy ::
      ResolveMissingReferences ::
      ExtractGenerator ::
      ResolveGenerate ::
      ResolveFunctions ::
      ResolveAliases ::
      ResolveSubquery ::
      ResolveWindowOrder ::
      ResolveWindowFrame ::
      ResolveNaturalAndUsingJoin ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
      TimeWindowing ::
      ResolveInlineTables(conf) ::
      TypeCoercion.typeCoercionRules ++
      extendedResolutionRules : _*),
    Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
    Batch("View", Once,
      AliasViewChild(conf)),
    Batch("Nondeterministic", Once,
      PullOutNondeterministic),
    Batch("UDF", Once,
      HandleNullInputsForUDF),
    Batch("FixNullability", Once,
      FixNullability),
    Batch("ResolveTimeZone", Once,
      ResolveTimeZone),
    Batch("Subquery", Once,
      UpdateOuterReferences),
    Batch("Cleanup", fixedPoint,
      CleanupAliases)
  )

  /**
   * Analyze cte definitions and substitute child plan with analyzed cte definitions.
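   *
   * A simple illustration (not an exhaustive case): in the query below, the reference to `cte`
   * in the outer SELECT is replaced by the analyzed plan of the CTE definition.
   * {{{
   *   WITH cte AS (SELECT a FROM t)
   *   SELECT a FROM cte
   * }}}
   * Later definitions may refer to earlier ones, which is why the definitions are folded left
   * and substituted into one another before being substituted into the child plan.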
   */
  object CTESubstitution extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case With(child, relations) =>
        substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) {
          case (resolved, (name, relation)) =>
            resolved :+ name -> execute(substituteCTE(relation, resolved))
        })
      case other => other
    }

    def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = {
      plan transformDown {
        case u: UnresolvedRelation =>
          cteRelations.find(x => resolver(x._1, u.tableIdentifier.table))
            .map(_._2).getOrElse(u)
        case other =>
          // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
          other transformExpressions {
            case e: SubqueryExpression =>
              e.withNewPlan(substituteCTE(e.plan, cteRelations))
          }
      }
    }
  }

  /**
   * Substitute child plan with WindowSpecDefinitions.
   */
  object WindowsSubstitution extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      // Lookup WindowSpecDefinitions. This rule works with unresolved children.
      case WithWindowDefinition(windowDefinitions, child) =>
        child.transform {
          case p => p.transformExpressions {
            case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
              val errorMessage =
                s"Window specification $windowName is not defined in the WINDOW clause."
              val windowSpecDefinition =
                windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
              WindowExpression(c, windowSpecDefinition)
          }
        }
    }
  }

  /**
   * Replaces [[UnresolvedAlias]]s with concrete aliases.
   */
  object ResolveAliases extends Rule[LogicalPlan] {
    private def assignAliases(exprs: Seq[NamedExpression]) = {
      exprs.zipWithIndex.map {
        case (expr, i) =>
          expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
            child match {
              case ne: NamedExpression => ne
              case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
              case e if !e.resolved => u
              case g: Generator => MultiAlias(g, Nil)
              case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)()
              case e: ExtractValue => Alias(e, toPrettySQL(e))()
              case e if optGenAliasFunc.isDefined =>
                Alias(child, optGenAliasFunc.get.apply(e))()
              case e => Alias(e, toPrettySQL(e))()
            }
          }
      }.asInstanceOf[Seq[NamedExpression]]
    }

    private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
      exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
        Aggregate(groups, assignAliases(aggs), child)

      case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
        g.copy(aggregations = assignAliases(g.aggregations))

      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
        if child.resolved && hasUnresolvedAlias(groupByExprs) =>
        Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)

      case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
        Project(assignAliases(projectList), child)
    }
  }

  object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
    /*
     *  GROUP BY a, b, c WITH ROLLUP
     *  is equivalent to
     *  GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( ) ).
     *  Group Count: N + 1 (N is the number of group expressions)
     *
     *  We need to get all of its subsets for the rule described above; each subset is
     *  represented as a sequence of expressions.
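     *
     *  For example (a small sketch of the expected behavior, using attribute names loosely):
     *  rollupExprs(Seq(a, b, c)) returns Seq(Seq(a, b, c), Seq(a, b), Seq(a), Seq()),
     *  which is exactly what `Seq(a, b, c).inits` produces.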
     */
    def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toSeq

    /*
     *  GROUP BY a, b, c WITH CUBE
     *  is equivalent to
     *  GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
     *  Group Count: 2 ^ N (N is the number of group expressions)
     *
     *  We need to get all of its subsets for a given GROUPBY expression; the subsets are
     *  represented as sequences of expressions.
     */
    def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
      case x :: xs =>
        val initial = cubeExprs(xs)
        initial.map(x +: _) ++ initial
      case Nil =>
        Seq(Seq.empty)
    }

    private def hasGroupingAttribute(expr: Expression): Boolean = {
      expr.collectFirst {
        case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
      }.isDefined
    }

    private[analysis] def hasGroupingFunction(e: Expression): Boolean = {
      e.collectFirst {
        case g: Grouping => g
        case g: GroupingID => g
      }.isDefined
    }

    private def replaceGroupingFunc(
        expr: Expression,
        groupByExprs: Seq[Expression],
        gid: Expression): Expression = {
      expr transform {
        case e: GroupingID =>
          if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
            Alias(gid, toPrettySQL(e))()
          } else {
            throw new AnalysisException(
              s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) do not match " +
                s"grouping columns (${groupByExprs.mkString(",")})")
          }
        case e @ Grouping(col: Expression) =>
          val idx = groupByExprs.indexOf(col)
          if (idx >= 0) {
            Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
              Literal(1)), ByteType), toPrettySQL(e))()
          } else {
            throw new AnalysisException(s"Column of grouping ($col) can't be found " +
              s"in grouping columns ${groupByExprs.mkString(",")}")
          }
      }
    }

    /*
     * Create new aliases for all group by expressions for the `Expand` operator.
     */
    private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = {
      groupByExprs.map {
        case e: NamedExpression => Alias(e, e.name)()
        case other => Alias(other, other.toString)()
      }
    }

    /*
     * Construct [[Expand]] operator with grouping sets.
     */
    private def constructExpand(
        selectedGroupByExprs: Seq[Seq[Expression]],
        child: LogicalPlan,
        groupByAliases: Seq[Alias],
        gid: Attribute): LogicalPlan = {
      // Change the nullability of group by aliases if necessary. For example, if we have
      // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we
      // should change the nullability of b to be TRUE.
      // TODO: For Cube/Rollup just set nullability to be `true`.
      val expandedAttributes = groupByAliases.map { alias =>
        if (selectedGroupByExprs.exists(!_.contains(alias.child))) {
          alias.toAttribute.withNullability(true)
        } else {
          alias.toAttribute
        }
      }

      val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs =>
        groupingSetExprs.map { expr =>
          val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse(
            failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases"))
          // Map alias to expanded attribute.
          expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse(
            alias.toAttribute)
        }
      }

      Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child)
    }

    /*
     * Construct new aggregate expressions by replacing grouping functions.
     */
    private def constructAggregateExprs(
        groupByExprs: Seq[Expression],
        aggregations: Seq[NamedExpression],
        groupByAliases: Seq[Alias],
        groupingAttrs: Seq[Expression],
        gid: Attribute): Seq[NamedExpression] = aggregations.map {
      // Collect all the found AggregateExpressions, so we can check whether an expression is
      // part of any AggregateExpression or not.
      val aggsBuffer = ArrayBuffer[Expression]()
      // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
      def isPartOfAggregation(e: Expression): Boolean = {
        aggsBuffer.exists(a => a.find(_ eq e).isDefined)
      }
      replaceGroupingFunc(_, groupByExprs, gid).transformDown {
        // AggregateExpression should be computed on the unmodified value of its argument
        // expressions, so we should not replace any references to grouping expressions
        // inside it.
        case e: AggregateExpression =>
          aggsBuffer += e
          e
        case e if isPartOfAggregation(e) => e
        case e =>
          // Replace expression by expand output attribute.
          val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
          if (index == -1) {
            e
          } else {
            groupingAttrs(index)
          }
      }.asInstanceOf[NamedExpression]
    }

    /*
     * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
     */
    private def constructAggregate(
        selectedGroupByExprs: Seq[Seq[Expression]],
        groupByExprs: Seq[Expression],
        aggregationExprs: Seq[NamedExpression],
        child: LogicalPlan): LogicalPlan = {
      val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()

      // Expand works by setting grouping expressions to null as determined by the
      // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
      // instead of the original value we need to create new aliases for all group by expressions
      // that will only be used for the intended purpose.
      val groupByAliases = constructGroupByAlias(groupByExprs)

      val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
      val groupingAttrs = expand.output.drop(child.output.length)

      val aggregations = constructAggregateExprs(
        groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid)

      Aggregate(groupingAttrs, aggregations, expand)
    }

    private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
      plan.collectFirst {
        case a: Aggregate =>
          // This Aggregate should have grouping id as the last grouping key.
          val gid = a.groupingExpressions.last
          if (!gid.isInstanceOf[AttributeReference]
            || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
            failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
          }
          a.groupingExpressions.take(a.groupingExpressions.length - 1)
      }.getOrElse {
        failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
      }
    }

    // This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort.
    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
      case a if !a.childrenResolved => a // be sure all of the children are resolved.

      case p if p.expressions.exists(hasGroupingAttribute) =>
        failAnalysis(
          s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")

      // Ensure group by expressions and aggregate expressions have been resolved.
      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child)
        if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
        constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child)
        if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
        constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child)

      // Ensure all the expressions have been resolved.
      case x: GroupingSets if x.expressions.forall(_.resolved) =>
        constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child)

      // We should make sure all expressions in condition have been resolved.
      case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
        val groupingExprs = findGroupingExprs(child)
        // The unresolved grouping id will be resolved by ResolveMissingReferences.
        val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
        f.copy(condition = newCond)

      // We should make sure all [[SortOrder]]s have been resolved.
      case s @ Sort(order, _, child)
        if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
        val groupingExprs = findGroupingExprs(child)
        val gid = VirtualColumn.groupingIdAttribute
        // The unresolved grouping id will be resolved by ResolveMissingReferences.
        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
        s.copy(order = newOrder)
    }
  }

  object ResolvePivot extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
        || !p.groupByExprs.forall(_.resolved) || !p.pivotColumn.resolved => p
      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
        val singleAgg = aggregates.size == 1
        def outputName(value: Literal, aggregate: Expression): String = {
          val utf8Value = Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
          val stringValue: String = Option(utf8Value).map(_.toString).getOrElse("null")
          if (singleAgg) {
            stringValue
          } else {
            val suffix = aggregate match {
              case n: NamedExpression => n.name
              case _ => toPrettySQL(aggregate)
            }
            stringValue + "_" + suffix
          }
        }
        if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
          // Since evaluating |pivotValues| if statements for each input row can get slow this is
          // an alternate plan that instead uses two steps of aggregation.
          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
          val namedPivotCol = pivotColumn match {
            case n: NamedExpression => n
            case _ => Alias(pivotColumn, "__pivot_col")()
          }
          val bigGroup = groupByExprs :+ namedPivotCol
          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
          val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
          val pivotAggs = namedAggExps.map { a =>
            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
              .toAggregateExpression(), "__pivot_" + a.sql)()
          }
          val groupByExprsAttr = groupByExprs.map(_.toAttribute)
          val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
            aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
              Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
            }
          }
          Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
        } else {
          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
            def ifExpr(expr: Expression) = {
              If(EqualNullSafe(pivotColumn, value), expr, Literal(null))
            }
            aggregates.map { aggregate =>
              val filteredAggregate = aggregate.transformDown {
                // Assumption is the aggregate function ignores nulls. This is true for all
                // current AggregateFunctions with the exception of First and Last in their
                // default mode (which we handle) and possibly some Hive UDAFs.
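                // For example (an illustrative sketch, not taken from a real plan): with pivot
                // value 'v', sum(earnings) is rewritten to
                // sum(if(pivotColumn <=> 'v', earnings, null)), so each copy of the aggregate
                // only sees the rows matching its pivot value.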
                case First(expr, _) =>
                  First(ifExpr(expr), Literal(true))
                case Last(expr, _) =>
                  Last(ifExpr(expr), Literal(true))
                case a: AggregateFunction =>
                  a.withNewChildren(a.children.map(ifExpr))
              }.transform {
                // We are duplicating aggregates that are now computing a different value for each
                // pivot value.
                // TODO: Don't construct the physical container until after analysis.
                case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
              }
              if (filteredAggregate.fastEquals(aggregate)) {
                throw new AnalysisException(
                  s"Aggregate expression required for pivot, found '$aggregate'")
              }
              Alias(filteredAggregate, outputName(value, aggregate))()
            }
          }
          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
        }
    }
  }

  /**
   * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
   */
  object ResolveRelations extends Rule[LogicalPlan] {

    // If the unresolved relation is running directly on files, we just return the original
    // UnresolvedRelation; the plan will get resolved later. Else we look up the table from the
    // catalog and change the default database name (in AnalysisContext) if it is a view.
    // We usually look up a table from the default database if the table identifier has an empty
    // database part; for a view the default database should be the currentDb when the view was
    // created. When it comes to resolving a nested view, the view may have a different default
    // database than the referenced view has, so we need to use
    // `AnalysisContext.defaultDatabase` to track the current default database.
    // When the relation we resolve is a view, we fetch the view.desc (which is a CatalogTable),
    // and then set the value of `CatalogTable.viewDefaultDatabase` to
    // `AnalysisContext.defaultDatabase`; we look up the relations that the view references using
    // the default database.
    // For example:
    // |- view1 (defaultDatabase = db1)
    //   |- operator
    //     |- table2 (defaultDatabase = db1)
    //     |- view2 (defaultDatabase = db2)
    //        |- view3 (defaultDatabase = db3)
    //   |- view4 (defaultDatabase = db4)
    // In this case, the view `view1` is a nested view, it directly references `table2`, `view2`
    // and `view4`, and the view `view2` references `view3`. On resolving the table, we look up
    // the relations `table2`, `view2`, `view4` using the default database `db1`, and look up the
    // relation `view3` using the default database `db2`.
    //
    // Note this is compatible with the views defined by older versions of Spark (before 2.2),
    // which have an empty defaultDatabase and all the relations in viewText have the database
    // part defined.
    def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
      case u: UnresolvedRelation if !isRunningDirectlyOnFiles(u.tableIdentifier) =>
        val defaultDatabase = AnalysisContext.get.defaultDatabase
        val relation = lookupTableFromCatalog(u, defaultDatabase)
        resolveRelation(relation)
      // The view's child should be a logical plan parsed from the `desc.viewText`; the variable
      // `viewText` should be defined, or else we throw an error on the generation of the View
      // operator.
      case view @ View(desc, _, child) if !child.resolved =>
        // Resolve all the UnresolvedRelations and Views in the child.
        val newChild = AnalysisContext.withAnalysisContext(desc.viewDefaultDatabase) {
          if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) {
            view.failAnalysis(s"The depth of view ${view.desc.identifier} exceeds the maximum " +
              s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " +
              "avoid errors. " +
              "Increase the value of spark.sql.view.maxNestedViewDepth to work around this.")
          }
          execute(child)
        }
        view.copy(child = newChild)
      case p @ SubqueryAlias(_, view: View) =>
        val newChild = resolveRelation(view)
        p.copy(child = newChild)
      case _ => plan
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
        EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
          case v: View =>
            u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
          case other => i.copy(table = other)
        }
      case u: UnresolvedRelation => resolveRelation(u)
    }

    // Look up the table with the given name from the catalog. The database we use is decided by
    // the following precedence:
    // 1. Use the database part of the table identifier, if it is defined;
    // 2. Use defaultDatabase, if it is defined (in this case, no temporary objects can be used,
    //    and the default database is only used to look up a view);
    // 3. Use the currentDb of the SessionCatalog.
    private def lookupTableFromCatalog(
        u: UnresolvedRelation,
        defaultDatabase: Option[String] = None): LogicalPlan = {
      val tableIdentWithDb = u.tableIdentifier.copy(
        database = u.tableIdentifier.database.orElse(defaultDatabase))
      try {
        catalog.lookupRelation(tableIdentWithDb)
      } catch {
        case _: NoSuchTableException =>
          u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}")
        // If the database is defined and that database is not found, throw an AnalysisException.
        // Note that if the database is not defined, it is possible we are looking up a temp view.
        case e: NoSuchDatabaseException =>
          u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " +
            s"database ${e.db} doesn't exist.")
      }
    }

    // If the database part is specified, and we support running SQL directly on files, and
    // it's not a temporary view, and the table does not exist, then let's just return the
    // original UnresolvedRelation. It is possible we are matching a query like "select *
    // from parquet.`/path/to/query`". The plan will get resolved in the rule `ResolveDataSource`.
    // Note that we are testing (!db_exists || !table_exists) because the catalog throws
    // an exception from tableExists if the database does not exist.
    private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
      table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
        (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
    }
  }

  /**
   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
   * a logical plan node's children.
   */
  object ResolveReferences extends Rule[LogicalPlan] {
    /**
     * Generate a new logical plan for the right child with different expression IDs
     * for all conflicting attributes.
     */
    private def dedupRight(left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
      val conflictingAttributes = left.outputSet.intersect(right.outputSet)
      logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
        s"between $left and $right")

      right.collect {
        // Handle base relations that might appear more than once.
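        // For example (a hedged illustration of the Dataset self-join case this rule supports):
        // `df.join(df, ...)` brings the same relation, and hence the same attribute IDs, in on
        // both sides, so the matching relation on the right is re-instantiated with fresh
        // expression IDs below.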
        case oldVersion: MultiInstanceRelation
          if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
          val newVersion = oldVersion.newInstance()
          (oldVersion, newVersion)

        case oldVersion: SerializeFromObject
          if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance())))

        // Handle projects that create conflicting aliases.
        case oldVersion @ Project(projectList, _)
          if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

        case oldVersion @ Aggregate(_, aggregateExpressions, _)
          if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

        case oldVersion: Generate
          if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
          val newOutput = oldVersion.generatorOutput.map(_.newInstance())
          (oldVersion, oldVersion.copy(generatorOutput = newOutput))

        case oldVersion @ Window(windowExpressions, _, _, child)
          if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
            .nonEmpty =>
          (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
      }
      // Only handle first case, others will be fixed on the next pass.
      .headOption match {
        case None =>
          /*
           * No result implies that there is a logical plan node that produces new references
           * that this rule cannot handle. When that is the case, there must be another rule
           * that resolves these conflicts. Otherwise, the analysis will fail.
           */
          right
        case Some((oldRelation, newRelation)) =>
          val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
          val newRight = right transformUp {
            case r if r == oldRelation => newRelation
          } transformUp {
            case other => other transformExpressions {
              case a: Attribute => dedupAttr(a, attributeRewrites)
              case s: SubqueryExpression =>
                s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
            }
          }
          newRight
      }
    }

    private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
      attrMap.get(attr).getOrElse(attr).withQualifier(attr.qualifier)
    }

    /**
     * The outer plan may have been de-duplicated and the function below updates the
     * outer references to refer to the de-duplicated attributes.
     *
     * For example (SQL):
     * {{{
     *   SELECT * FROM t1
     *   INTERSECT
     *   SELECT * FROM t1
     *   WHERE EXISTS (SELECT 1
     *                 FROM t2
     *                 WHERE t1.c1 = t2.c1)
     * }}}
     * Plan before the resolveReference rule.
     *    'Intersect
     *    :- Project [c1#245, c2#246]
     *    :  +- SubqueryAlias t1
     *    :     +- Relation[c1#245,c2#246] parquet
     *    +- 'Project [*]
     *       +- Filter exists#257 [c1#245]
     *          :  +- Project [1 AS 1#258]
     *          :     +- Filter (outer(c1#245) = c1#251)
     *          :        +- SubqueryAlias t2
     *          :           +- Relation[c1#251,c2#252] parquet
     *          +- SubqueryAlias t1
     *             +- Relation[c1#245,c2#246] parquet
     * Plan after the resolveReference rule.
     *    Intersect
     *    :- Project [c1#245, c2#246]
     *    :  +- SubqueryAlias t1
     *    :     +- Relation[c1#245,c2#246] parquet
     *    +- Project [c1#259, c2#260]
     *       +- Filter exists#257 [c1#259]
     *          :  +- Project [1 AS 1#258]
     *          :     +- Filter (outer(c1#259) = c1#251) => Updated
     *          :        +- SubqueryAlias t2
     *          :           +- Relation[c1#251,c2#252] parquet
     *          +- SubqueryAlias t1
     *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated.
     */
    private def dedupOuterReferencesInSubquery(
        plan: LogicalPlan,
        attrMap: AttributeMap[Attribute]): LogicalPlan = {
      plan transformDown { case currentFragment =>
        currentFragment transformExpressions {
          case OuterReference(a: Attribute) =>
            OuterReference(dedupAttr(a, attrMap))
          case s: SubqueryExpression =>
            s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap))
        }
      }
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p: LogicalPlan if !p.childrenResolved => p

      // If the projection list contains Stars, expand it.
      case p: Project if containsStar(p.projectList) =>
        p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
      // If the aggregate function argument contains Stars, expand it.
      case a: Aggregate if containsStar(a.aggregateExpressions) =>
        if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
          failAnalysis(
            "Star (*) is not allowed in select list when GROUP BY ordinal position is used")
        } else {
          a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
        }
      // If the script transformation input contains Stars, expand it.
      case t: ScriptTransformation if containsStar(t.input) =>
        t.copy(
          input = t.input.flatMap {
            case s: Star => s.expand(t.child, resolver)
            case o => o :: Nil
          }
        )
      case g: Generate if containsStar(g.generator.children) =>
        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")

      // To resolve duplicate expression IDs for Join and Intersect.
      case j @ Join(left, right, _, _) if !j.duplicateResolved =>
        j.copy(right = dedupRight(left, right))
      case i @ Intersect(left, right) if !i.duplicateResolved =>
        i.copy(right = dedupRight(left, right))
      case i @ Except(left, right) if !i.duplicateResolved =>
        i.copy(right = dedupRight(left, right))

      // When resolving `SortOrder`s in Sort based on the child, don't report errors as
      // we still have a chance to resolve them based on their descendants.
      case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
        val newOrdering =
          ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
        Sort(newOrdering, global, child)

      // A special case for Generate, because the output of Generate should not be resolved by
      // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
      case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g

      case g @ Generate(generator, join, outer, qualifier, output, child) =>
        val newG = resolveExpression(generator, child, throws = true)
        if (newG.fastEquals(generator)) {
          g
        } else {
          Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
        }

      // Skips plans which contain deserializer expressions, as they should be resolved by
      // another rule: ResolveDeserializer.
      case plan if containsDeserializer(plan.expressions) => plan

      case q: LogicalPlan =>
        logTrace(s"Attempting to resolve ${q.simpleString}")
        q transformExpressionsUp {
          case u @ UnresolvedAttribute(nameParts) =>
            // Leave unchanged if resolution fails. Hopefully it will be resolved next round.
            val result =
              withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
            logDebug(s"Resolving $u to $result")
            result
          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
            ExtractValue(child, fieldExpr, resolver)
        }
    }

    def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
      expressions.map {
        case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
        case other => other
      }
    }

    def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
      AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
    }

    /**
     * Build a project list for Project/Aggregate and expand the star if possible.
     */
    private def buildExpandedProjectList(
        exprs: Seq[NamedExpression],
        child: LogicalPlan): Seq[NamedExpression] = {
      exprs.flatMap {
        // Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
        case s: Star => s.expand(child, resolver)
        // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
        case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
        case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
        case o => o :: Nil
      }.map(_.asInstanceOf[NamedExpression])
    }

    /**
     * Returns true if `exprs` contains a [[Star]].
     */
    def containsStar(exprs: Seq[Expression]): Boolean =
      exprs.exists(_.collect { case _: Star => true }.nonEmpty)

    /**
     * Expands the matching attribute.*'s in `child`'s output.
     */
    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
      expr.transformUp {
        case f1: UnresolvedFunction if containsStar(f1.children) =>
          f1.copy(children = f1.children.flatMap {
            case s: Star => s.expand(child, resolver)
            case o => o :: Nil
          })
        case c: CreateNamedStruct if containsStar(c.valExprs) =>
          val newChildren = c.children.grouped(2).flatMap {
            case Seq(k, s: Star) => CreateStruct(s.expand(child, resolver)).children
            case kv => kv
          }
          c.copy(children = newChildren.toList)
        case c: CreateArray if containsStar(c.children) =>
          c.copy(children = c.children.flatMap {
            case s: Star => s.expand(child, resolver)
            case o => o :: Nil
          })
        case p: Murmur3Hash if containsStar(p.children) =>
          p.copy(children = p.children.flatMap {
            case s: Star => s.expand(child, resolver)
            case o => o :: Nil
          })
        // count(*) has been replaced by count(1)
        case o if containsStar(o.children) =>
          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
      }
    }
  }

  private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
    exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
  }

  protected[sql] def resolveExpression(
      expr: Expression,
      plan: LogicalPlan,
      throws: Boolean = false) = {
    // Resolve the expression in one round.
    // If `throws` is false and the desired attribute doesn't exist (e.g. trying to resolve
    // `a.b` when `a` doesn't exist), return the original expression.
    // Otherwise, throw the exception.
    try {
      expr transformUp {
        case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
        case u @ UnresolvedAttribute(nameParts) =>
          withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
        case UnresolvedExtractValue(child, fieldName) if child.resolved =>
          ExtractValue(child, fieldName, resolver)
      }
    } catch {
      case a: AnalysisException if !throws => expr
    }
  }

  /**
   * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
   * clauses. This rule converts ordinal positions to the corresponding expressions in the
   * select list. This support was introduced in Spark 2.0.
   *
   * - When the sort references or group by expressions are not integer but foldable expressions,
   *   just ignore them.
   * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the
   *   position numbers too.
   *
   * Before the release of Spark 2.0, the literals in order/sort by and group by clauses
   * had no effect on the results.
   */
  object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.childrenResolved => p
      // Replace the index with the related attribute for ORDER BY,
      // which is a 1-based position in the projection list.
      case s @ Sort(orders, global, child)
        if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
        val newOrders = orders map {
          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
            if (index > 0 && index <= child.output.size) {
              SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
            } else {
              s.failAnalysis(
                s"ORDER BY position $index is not in select list " +
                  s"(valid range is [1, ${child.output.size}])")
            }
          case o => o
        }
        Sort(newOrders, global, child)

      // Replace the index with the corresponding expression in aggregateExpressions. The index
      // is a 1-based position in aggregateExpressions, which are the output columns
      // (select expressions).
      case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
        groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
        val newGroups = groups.map {
          case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
            aggs(index - 1) match {
              case e if ResolveAggregateFunctions.containsAggregate(e) =>
                ordinal.failAnalysis(
                  s"GROUP BY position $index is an aggregate function, and " +
                    "aggregate functions are not allowed in GROUP BY")
              case o => o
            }
          case ordinal @ UnresolvedOrdinal(index) =>
            ordinal.failAnalysis(
              s"GROUP BY position $index is not in select list " +
                s"(valid range is [1, ${aggs.size}])")
          case o => o
        }
        Aggregate(newGroups, aggs, child)
    }
  }

  /**
   * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
   * clause. This rule detects such queries and adds the required attributes to the original
   * projection, so that they will be available during sorting. Another projection is added to
   * remove these attributes after sorting.
   *
   * The HAVING clause could also use a grouping column that is not present in the SELECT.
   */
  object ResolveMissingReferences extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions.
      case sa @ Sort(_, _, child: Aggregate) => sa

      case s @ Sort(order, _, child) if child.resolved =>
        try {
          val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
          val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
          val missingAttrs = requiredAttrs -- child.outputSet
          if (missingAttrs.nonEmpty) {
            // Add missing attributes and then project them away after the sort.
            Project(child.output,
              Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
          } else if (newOrder != order) {
            s.copy(order = newOrder)
          } else {
            s
          }
        } catch {
          // Attempting to resolve it might fail. When this happens, return the original plan.
          // Users will see an AnalysisException for resolution failure of missing attributes
          // in Sort.
          case ae: AnalysisException => s
        }

      case f @ Filter(cond, child) if child.resolved =>
        try {
          val newCond = resolveExpressionRecursively(cond, child)
          val requiredAttrs = newCond.references.filter(_.resolved)
          val missingAttrs = requiredAttrs -- child.outputSet
          if (missingAttrs.nonEmpty) {
            // Add missing attributes and then project them away.
            Project(child.output,
              Filter(newCond, addMissingAttr(child, missingAttrs)))
          } else if (newCond != cond) {
            f.copy(condition = newCond)
          } else {
            f
          }
        } catch {
          // Attempting to resolve it might fail. When this happens, return the original plan.
          // Users will see an AnalysisException for resolution failure of missing attributes.
          case ae: AnalysisException => f
        }
    }

    /**
     * Add the missing attributes into the projectList of Project/Window or the
     * aggregateExpressions of Aggregate.
     */
    private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
      if (missingAttrs.isEmpty) {
        return plan
      }
      plan match {
        case p: Project =>
          val missing = missingAttrs -- p.child.outputSet
          Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
        case a: Aggregate =>
          // All the missing attributes should be grouping expressions.
          // TODO: push down AggregateExpression
          missingAttrs.foreach { attr =>
            if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
              throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
            }
          }
          val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
          a.copy(aggregateExpressions = newAggregateExpressions)
        case g: Generate =>
          // If join is false, we convert it to true so that the missing attributes can be
          // obtained from the child.
          val missing = missingAttrs -- g.child.outputSet
          g.copy(join = true, child = addMissingAttr(g.child, missing))
        case d: Distinct =>
          throw new AnalysisException(s"Can't add $missingAttrs to $d")
        case u: UnaryNode =>
          u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
        case other =>
          throw new AnalysisException(s"Can't add $missingAttrs to $other")
      }
    }

    /**
     * Resolve the expression on the specified logical plan and its child (recursively), until
     * the expression is resolved or we reach a non-unary node or a SubqueryAlias.
     */
    @tailrec
    private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
      val resolved = resolveExpression(expr, plan)
      if (resolved.resolved) {
        resolved
      } else {
        plan match {
          case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
            resolveExpressionRecursively(resolved, u.child)
          case other => resolved
        }
      }
    }
  }

  /**
   * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in
   * the function registry. Note that this rule doesn't try to resolve the
   * [[UnresolvedFunction]]. It only performs a simple existence check according to the function
   * identifier to quickly identify undefined functions without triggering relation resolution,
   * which may incur a potentially expensive partition/schema discovery process in some cases.
   *
   * @see [[ResolveFunctions]]
   * @see https://issues.apache.org/jira/browse/SPARK-19737
   */
  object LookupFunctions extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
      case f: UnresolvedFunction if !catalog.functionExists(f.name) =>
        withPosition(f) {
          throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName)
        }
    }
  }

  /**
   * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
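   *
   * For example (an illustrative sketch of the cases handled below, not an exhaustive list):
   * an unresolved `max(a)` is looked up in the function registry and wrapped as
   * `AggregateExpression(Max(a), Complete, isDistinct = false)`, while a window function such
   * as `row_number()` is returned as-is because it is only evaluated within a Window clause.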
   */
  object ResolveFunctions extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case q: LogicalPlan =>
        q transformExpressions {
          case u if !u.childrenResolved => u // Skip until children are resolved.

          case u @ UnresolvedGenerator(name, children) =>
            withPosition(u) {
              catalog.lookupFunction(name, children) match {
                case generator: Generator => generator
                case other =>
                  failAnalysis(s"$name is expected to be a generator. However, " +
                    s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
              }
            }

          case u @ UnresolvedFunction(funcId, children, isDistinct) =>
            withPosition(u) {
              catalog.lookupFunction(funcId, children) match {
                // DISTINCT is not meaningful for a Max or a Min.
                case max: Max if isDistinct =>
                  AggregateExpression(max, Complete, isDistinct = false)
                case min: Min if isDistinct =>
                  AggregateExpression(min, Complete, isDistinct = false)
                // AggregateWindowFunctions are AggregateFunctions that can only be evaluated
                // within the context of a Window clause. They do not need to be wrapped in an
                // AggregateExpression.
                case wf: AggregateWindowFunction => wf
                // We get an aggregate function; we need to wrap it in an AggregateExpression.
                case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
                // This function is not an aggregate function, just return the resolved one.
                case other => other
              }
            }
        }
    }
  }

  /**
   * This rule resolves and rewrites subqueries inside expressions.
   *
   * Note: CTEs are handled in CTESubstitution.
   */
  object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
    /**
     * Resolve the correlated expressions in a subquery using the outer plan's references. All
     * resolved outer references are wrapped in an [[OuterReference]].
     */
    private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
      plan transformDown {
        case q: LogicalPlan if q.childrenResolved && !q.resolved =>
          q transformExpressions {
            case u @ UnresolvedAttribute(nameParts) =>
              withPosition(u) {
                try {
                  outer.resolve(nameParts, resolver) match {
                    case Some(outerAttr) => OuterReference(outerAttr)
                    case None => u
                  }
                } catch {
                  case _: AnalysisException => u
                }
              }
          }
      }
    }

    /**
     * Validates that the outer references appearing inside the subquery are legal. This
     * function also returns the list of expressions that contain outer references. These outer
     * references are kept as children of subquery expressions by the caller of this function.
     */
    private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = {
      val outerReferences = ArrayBuffer.empty[Expression]

      // Make sure a plan's subtree does not contain outer references.
      def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = {
        if (hasOuterReferences(p)) {
          failAnalysis(s"Accessing outer query column is not allowed in:\n$p")
        }
      }

      // Make sure a plan's expressions do not contain outer references.
      def failOnOuterReference(p: LogicalPlan): Unit = {
        if (p.expressions.exists(containsOuter)) {
          failAnalysis(
            "Expressions referencing the outer query are not supported outside of WHERE/HAVING " +
              s"clauses:\n$p")
        }
      }

      // SPARK-17348: A potential incorrect result case.
      // When a correlated predicate is a non-equality predicate,
      // certain operators are not permitted from the operator
      // hosting the correlated predicate up to the operator on the outer table.
      // Otherwise, the pull up of the correlated predicate
      // will generate a plan with different semantics,
      // which could return incorrect results.
      // Currently we check for Aggregate and Window operators.
      //
      // Below shows an example of a logical plan during the Analyzer phase that
      // exhibits this problem. Pulling the correlated predicate [outer(c2#77) >= ..]
      // through the Aggregate (or Window) operator could alter the result of
      // the Aggregate.
      //
      // Project [c1#76]
      // +- Project [c1#87, c2#88]
      // :  (Aggregate or Window operator)
      // :  +- Filter [outer(c2#77) >= c2#88)]
      // :     +- SubqueryAlias t2, `t2`
      // :        +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
      // :           +- LocalRelation [_1#84, _2#85]
      // +- SubqueryAlias t1, `t1`
      //    +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
      //       +- LocalRelation [_1#73, _2#74]
      def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = {
        if (found) {
          // Report a non-supported case as an exception.
          failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p")
        }
      }

      var foundNonEqualCorrelatedPred: Boolean = false

      // Simplify the predicates before validating any unsupported correlation patterns
      // in the plan.
      BooleanSimplification(sub).foreachUp {
        // Whitelist operators allowed in a correlated subquery.
        // There are 4 categories:
        // 1. Operators that are allowed anywhere in a correlated subquery, and,
        //    by definition of the operators, they either do not contain
        //    any columns or cannot host outer references.
        // 2. Operators that are allowed anywhere in a correlated subquery
        //    so long as they do not host outer references.
        // 3. Operators that need special handling. These operators are
        //    Project, Filter, Join, Aggregate, and Generate.
        // 4. Any operators that are not in the above list are allowed
        //    in a correlated subquery only if they are not on a correlation path.
        //    In other words, these operators are allowed only under a correlation point.
        //
        // A correlation path is defined as the sub-tree of all the operators that
        // are on the path from the operator hosting the correlated expressions
        // up to the operator producing the correlated values.

        // Category 1:
        // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias
        case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias =>

        // Category 2:
        // These operators can be anywhere in a correlated subquery
        // so long as they do not host outer references in the operators.
        case s: Sort =>
          failOnOuterReference(s)
        case r: RepartitionByExpression =>
          failOnOuterReference(r)

        // Category 3:
        // Filter is one of the two operators allowed to host correlated expressions.
        // The other operator is Join. Filter can be anywhere in a correlated subquery.
        case f: Filter =>
          // Find all predicates with an outer reference.
          val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter)

          // Find any non-equality correlated predicates.
          foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists {
            case _: EqualTo | _: EqualNullSafe => false
            case _ => true
          }
          // The aggregate expressions are treated in a special way by getOuterReferences. If
          // the aggregate expression contains only outer reference attributes then the entire
          // aggregate expression is isolated as an OuterReference.
          // i.e. min(OuterReference(b)) => OuterReference(min(b))
          outerReferences ++= getOuterReferences(correlated)

        // Project cannot host any correlated expressions
        // but can be anywhere in a correlated subquery.
        case p: Project =>
          failOnOuterReference(p)

        // Aggregate cannot host any correlated expressions.
        // It can be on a correlation path if the correlation contains
        // only equality correlated predicates.
        // It cannot be on a correlation path if the correlation has
        // non-equality correlated predicates.
        case a: Aggregate =>
          failOnOuterReference(a)
          failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a)

        // Join can host correlated expressions.
        case j @ Join(left, right, joinType, _) =>
          joinType match {
            // Inner join, like Filter, can be anywhere.
            case _: InnerLike =>
              failOnOuterReference(j)

            // Left outer join's right operand cannot be on a correlation path.
            // LeftAnti and ExistenceJoin are special cases of LeftOuter.
            // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame
            // so it should not show up here in the Analysis phase. This is just a safety net.
            //
            // LeftSemi does not allow output from the right operand.
            // Any correlated references in the subplan
            // of the right operand cannot be pulled up.
            case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
              failOnOuterReference(j)
              failOnOuterReferenceInSubTree(right)

            // Likewise, Right outer join's left operand cannot be on a correlation path.
            case RightOuter =>
              failOnOuterReference(j)
              failOnOuterReferenceInSubTree(left)

            // Any other join types not explicitly listed above,
            // including Full outer join, are treated as Category 4.
            case _ =>
              failOnOuterReferenceInSubTree(j)
          }

        // Generator with join=true, i.e., expressed with
        // LATERAL VIEW [OUTER], similar to inner join,
        // allows correlation under it
        // but must not host any outer references.
        // Note:
        // Generator with join=false is treated as Category 4.
        case g: Generate if g.join =>
          failOnOuterReference(g)

        // Category 4: Any other operators not in the above 3 categories
        // cannot be on a correlation path, that is they are allowed only
        // under a correlation point but they and their descendant operators
        // are not allowed to have any correlated expressions.
        case p =>
          failOnOuterReferenceInSubTree(p)
      }
      outerReferences
    }

    /**
     * Resolves the subquery. The subquery is resolved using its outer plans. This method
     * will resolve the subquery by alternating between the regular analyzer and the
     * resolveOuterReferences rule.
     *
     * Outer references from the correlated predicates are updated as children of the
     * Subquery expression.
     */
    private def resolveSubQuery(
        e: SubqueryExpression,
        plans: Seq[LogicalPlan],
        requiredColumns: Int = 0)(
        f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
      // Step 1: Resolve the outer expressions.
      var previous: LogicalPlan = null
      var current = e.plan
      do {
        // Try to resolve the subquery plan using the regular analyzer.
        previous = current
        current = execute(current)

        // Use the outer references to resolve the subquery plan if it isn't resolved yet.
        val i = plans.iterator
        val afterResolve = current
        while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
          current = resolveOuterReferences(current, i.next())
        }
      } while (!current.resolved && !current.fastEquals(previous))

      // Step 2: If the subquery plan is fully resolved, pull the outer references and record
      // them as children of SubqueryExpression.
      if (current.resolved) {
        // Make sure the resolved query has the required number of output columns. This is only
        // needed for Scalar and IN subqueries.
        if (requiredColumns > 0 && requiredColumns != current.output.size) {
          failAnalysis(s"The number of columns in the subquery (${current.output.size}) " +
            s"does not match the required number of columns ($requiredColumns)")
        }
        // Validate the outer references and record them as children of the
        // subquery expression.
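        // For example (a sketch; the table and column names are illustrative only), for
        //   SELECT * FROM t1 WHERE EXISTS (SELECT 1 FROM t2 WHERE t2.c1 = t1.c1)
        // the resolved reference outer(t1.c1) is validated by checkAndGetOuterReferences and
        // recorded as a child of the Exists expression built by `f`.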
        f(current, checkAndGetOuterReferences(current))
      } else {
        e.withNewPlan(current)
      }
    }

    /**
     * Resolves the subquery. Apart from resolving the subquery and the outer references (if
     * any) in the subquery plan, the children of the subquery expression are updated to record
     * the outer references. This is needed to make sure
     * (1) The column(s) referred from the outer query are not pruned from the plan during
     *     optimization.
     * (2) Any aggregate expression(s) that reference outer attributes are pushed down to the
     *     outer plan to get evaluated.
     */
    private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
      plan transformExpressions {
        case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
          resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
        case e @ Exists(sub, _, exprId) if !sub.resolved =>
          resolveSubQuery(e, plans)(Exists(_, _, exprId))
        case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
          // Get the left hand side expressions.
          val expressions = value match {
            case cns: CreateNamedStruct => cns.valExprs
            case expr => Seq(expr)
          }
          val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId))
          In(value, Seq(expr))
      }
    }

    /**
     * Resolve and rewrite all subqueries in an operator tree.
     */
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      // In case of HAVING (a filter after an aggregate) we use both the aggregate and
      // its child for resolution.
      case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
        resolveSubQueries(f, Seq(a, a.child))
      // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
      case q: UnaryNode if q.childrenResolved =>
        resolveSubQueries(q, q.children)
    }
  }

  /**
   * Turns projections that contain aggregate expressions into aggregations.
   */
  object GlobalAggregates extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case Project(projectList, child) if containsAggregates(projectList) =>
        Aggregate(Nil, projectList, child)
    }

    def containsAggregates(exprs: Seq[Expression]): Boolean = {
      // Collect all Windowed Aggregate Expressions.
      val windowedAggExprs = exprs.flatMap { expr =>
        expr.collect {
          case WindowExpression(ae: AggregateExpression, _) => ae
        }
      }.toSet

      // Find the first Aggregate Expression that is not Windowed.
      exprs.exists(_.collectFirst {
        case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae
      }.isDefined)
    }
  }

  /**
   * This rule finds aggregate expressions that are not in an aggregate operator. For example,
   * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
   * underlying aggregate operator and then projected away after the original operator.
   */
  object ResolveAggregateFunctions extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case filter @ Filter(havingCondition,
          aggregate @ Aggregate(grouping, originalAggExprs, child))
          if aggregate.resolved =>

        // Try resolving the condition of the filter as though it is in the aggregate clause.
        try {
          val aggregatedCondition =
            Aggregate(
              grouping,
              Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
              child)
          val resolvedOperator = execute(aggregatedCondition)
          def resolvedAggregateFilter =
            resolvedOperator
              .asInstanceOf[Aggregate]
              .aggregateExpressions.head

          // If resolution was successful and we see the filter has an aggregate in it, add it to
          // the original aggregate operator.
          if (resolvedOperator.resolved) {
            // Try to replace all aggregate expressions in the filter by an alias.
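            // For example (an illustrative sketch with hypothetical table/column names):
            //   SELECT a FROM t GROUP BY a HAVING count(b) > 1
            // pushes count(b) into the Aggregate as an aliased expression, filters on the
            // alias's attribute, and projects the extra column away above the Filter.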
            val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
            val transformedAggregateFilter = resolvedAggregateFilter.transform {
              case ae: AggregateExpression =>
                val alias = Alias(ae, ae.toString)()
                aggregateExpressions += alias
                alias.toAttribute
              // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
              case e: Expression if grouping.exists(_.semanticEquals(e)) &&
                  !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
                  !aggregate.output.exists(_.semanticEquals(e)) =>
                e match {
                  case ne: NamedExpression =>
                    aggregateExpressions += ne
                    ne.toAttribute
                  case _ =>
                    val alias = Alias(e, e.toString)()
                    aggregateExpressions += alias
                    alias.toAttribute
                }
            }

            // Push the aggregate expressions into the aggregate (if any).
            if (aggregateExpressions.nonEmpty) {
              Project(aggregate.output,
                Filter(transformedAggregateFilter,
                  aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
            } else {
              filter
            }
          } else {
            filter
          }
        } catch {
          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
          // just return the original plan.
          case ae: AnalysisException => filter
        }

      case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

        // Try resolving the ordering as though it is in the aggregate clause.
        try {
          val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
          val aliasedOrdering =
            unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true))
          val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
          val resolvedAggregate: Aggregate =
            execute(aggregatedOrdering).asInstanceOf[Aggregate]
          val resolvedAliasedOrdering: Seq[Alias] =
            resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]

          // If we pass the analysis check, then the ordering expressions should only reference
          // aggregate expressions or grouping expressions, and it's safe to push them down to
          // the Aggregate.
          checkAnalysis(resolvedAggregate)

          val originalAggExprs = aggregate.aggregateExpressions.map(
            CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])

          // If the ordering expression is the same as an original aggregate expression, we
          // don't need to push down this ordering expression and can reference the original
          // aggregate expression instead.
          val needsPushDown = ArrayBuffer.empty[NamedExpression]
          val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
            case (evaluated, order) =>
              val index = originalAggExprs.indexWhere {
                case Alias(child, _) => child semanticEquals evaluated.child
                case other => other semanticEquals evaluated.child
              }

              if (index == -1) {
                needsPushDown += evaluated
                order.copy(child = evaluated.toAttribute)
              } else {
                order.copy(child = originalAggExprs(index).toAttribute)
              }
          }

          val sortOrdersMap = unresolvedSortOrders
            .map(new TreeNodeRef(_))
            .zip(evaluatedOrderings)
            .toMap
          val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))

          // Since we don't rely on sort.resolved as the stop condition for this rule,
          // we need to check this and prevent applying this rule multiple times.
          if (sortOrder == finalSortOrders) {
            sort
          } else {
            Project(aggregate.output,
              Sort(finalSortOrders, global,
                aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
          }
        } catch {
          // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
          // just return the original plan.
          case ae: AnalysisException => sort
        }
    }

    def containsAggregate(condition: Expression): Boolean = {
      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
    }
  }

  /**
   * Extracts [[Generator]] from the projectList of a [[Project]] operator and creates a
   * [[Generate]] operator under [[Project]].
   *
   * This rule will throw an [[AnalysisException]] in the following cases:
   * 1. The [[Generator]] is nested in an expression, e.g. `SELECT explode(list) + 1 FROM tbl`
   * 2. More than one [[Generator]] is found in the projectList,
   *    e.g. `SELECT explode(list), explode(list) FROM tbl`
   * 3. A [[Generator]] is found in an operator other than [[Project]] or [[Generate]],
   *    e.g. `SELECT * FROM tbl SORT BY explode(list)`
   */
  object ExtractGenerator extends Rule[LogicalPlan] {
    private def hasGenerator(expr: Expression): Boolean = {
      expr.find(_.isInstanceOf[Generator]).isDefined
    }

    private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match {
      case UnresolvedAlias(_: Generator, _) => false
      case Alias(_: Generator, _) => false
      case MultiAlias(_: Generator, _) => false
      case other => hasGenerator(other)
    }

    private def trimAlias(expr: NamedExpression): Expression = expr match {
      case UnresolvedAlias(child, _) => child
      case Alias(child, _) => child
      case MultiAlias(child, _) => child
      case _ => expr
    }

    private object AliasedGenerator {
      /**
       * Extracts a [[Generator]] expression, any names assigned by aliases to the outputs
       * and the outer flag. The outer flag is used when joining the generator output.
       * @param e the [[Expression]]
       * @return (the [[Generator]], seq of output names, outer flag)
       */
      def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match {
        case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true))
        case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some((g, names, true))
        case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false))
        case MultiAlias(g: Generator, names) if g.resolved => Some((g, names, false))
        case _ => None
      }
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
        val nestedGenerator = projectList.find(hasNestedGenerator).get
        throw new AnalysisException("Generators are not supported when nested in " +
          "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))

      case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
        val generators = projectList.filter(hasGenerator).map(trimAlias)
        throw new AnalysisException("Only one generator allowed per select clause but found " +
          generators.size + ": " + generators.map(toPrettySQL).mkString(", "))

      case p @ Project(projectList, child) =>
        // Holds the resolved generator, if one exists in the project list.
        var resolvedGenerator: Generate = null

        val newProjectList = projectList.flatMap {
          case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
            // This is a sanity check; it should not happen, as the previous case would have
            // thrown an exception earlier.
            assert(resolvedGenerator == null, "More than one generator found in SELECT.")

            resolvedGenerator =
              Generate(
                generator,
                join = projectList.size > 1, // Only join if there are other expressions in SELECT.
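                // e.g. a hypothetical `SELECT explode(arr), b FROM t` must join the
                // generator's output back with `b`, while `SELECT explode(arr) FROM t`
                // needs no join at all.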
                outer = outer,
                qualifier = None,
                generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
                child)

            resolvedGenerator.generatorOutput
          case other => other :: Nil
        }

        if (resolvedGenerator != null) {
          Project(newProjectList, resolvedGenerator)
        } else {
          p
        }

      case g: Generate => g

      case p if p.expressions.exists(hasGenerator) =>
        throw new AnalysisException("Generators are not supported outside the SELECT clause, but " +
          "got: " + p.simpleString)
    }
  }

  /**
   * Rewrites table generating expressions that need one or more of the following in order
   * to be resolved:
   *  - concrete attribute references for their output.
   *  - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]].
   *
   * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions
   * that wrap the [[Generator]].
   */
  object ResolveGenerate extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case g: Generate if !g.child.resolved || !g.generator.resolved => g
      case g: Generate if !g.resolved =>
        g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
    }

    /**
     * Construct the output attributes for a [[Generator]], given a list of names. If the list of
     * names is empty, names are assigned from the field names of the generator's schema.
     */
    private[analysis] def makeGeneratorOutput(
        generator: Generator,
        names: Seq[String]): Seq[Attribute] = {
      val elementAttrs = generator.elementSchema.toAttributes

      if (names.length == elementAttrs.length) {
        names.zip(elementAttrs).map {
          case (name, attr) => attr.withName(name)
        }
      } else if (names.isEmpty) {
        elementAttrs
      } else {
        failAnalysis(
          "The number of aliases supplied in the AS clause does not match the number of columns " +
          s"output by the UDTF: expected ${elementAttrs.size} aliases but got " +
          s"${names.mkString(",")}")
      }
    }
  }

  /**
   * Fixes the nullability of Attributes in a resolved LogicalPlan by using the nullability of
   * the corresponding Attributes in its children's output. This step is needed because
   * users can use a resolved AttributeReference in the Dataset API, and outer joins
   * can change the nullability of an AttributeReference. Without the fix, a nullable column's
   * nullable field can actually be set as non-nullable, which causes illegal optimizations
   * (e.g., NULL propagation) and wrong answers.
   * See SPARK-13484 and SPARK-13801 for the concrete queries of this case.
   */
  object FixNullability extends Rule[LogicalPlan] {

    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
      case p if !p.resolved => p // Skip unresolved nodes.
      case p: LogicalPlan if p.resolved =>
        val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap {
          case (exprId, attributes) =>
            // If there are multiple Attributes having the same ExprId, we need to resolve
            // the conflict of nullable fields. We do not really expect this to happen.
            val nullable = attributes.exists(_.nullable)
            attributes.map(attr => attr.withNullability(nullable))
        }.toSeq
        // Here, we create an AttributeMap that only compares the exprId for the lookup
        // operation, so we can find the corresponding input attribute's nullability.
        val attributeMap = AttributeMap[Attribute](childrenOutput.map(attr => attr -> attr))
        // For an Attribute used by the current LogicalPlan, if it is from its children,
        // we fix the nullable field by using the nullability setting of the corresponding
        // output Attribute from the children.
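        // For example (a hypothetical Dataset query), after
        // df1.join(df2, Seq("k"), "left_outer"), an attribute captured from df2 before the
        // join may still be marked non-nullable even though the outer join can now produce
        // nulls for it; the lookup below repairs that flag.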
        p.transformExpressions {
          case attr: Attribute if attributeMap.contains(attr) =>
            attr.withNullability(attributeMap(attr).nullable)
        }
    }
  }

  /**
   * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and the
   * aggregateExpressions of an [[Aggregate]] operator, and creates individual [[Window]]
   * operators for every distinct [[WindowSpecDefinition]].
   *
   * This rule handles three cases:
   *  - A [[Project]] having [[WindowExpression]]s in its projectList;
   *  - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions.
   *  - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
   *    clause, where the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
   * Note: If there is a GROUP BY clause in the query, aggregations and corresponding
   * filters (expressions in the HAVING clause) should be evaluated before any
   * [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
   * evaluated after all [[WindowExpression]]s.
   *
   * For every case, the transformation works as follows:
   * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partition
   *    it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
   *    all regular expressions.
   * 2. For all [[WindowExpression]]s, group them based on their [[WindowSpecDefinition]]s.
   * 3. For every distinct [[WindowSpecDefinition]], create a [[Window]] operator and insert
   *    it into the plan tree.
   */
  object ExtractWindowExpressions extends Rule[LogicalPlan] {
    private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
      projectList.exists(hasWindowFunction)

    private def hasWindowFunction(expr: NamedExpression): Boolean = {
      expr.find {
        case window: WindowExpression => true
        case _ => false
      }.isDefined
    }

    /**
     * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
     * other regular expressions that do not contain any window expression. For example, for
     * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract
     * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in
     * the window expression as attribute references. So, the first returned value will be
     * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be
     * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2].
     *
     * @return (seq of expressions containing at least one window expression,
     *          seq of non-window expressions)
     */
    private def extract(
        expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
      // First, we partition the input expressions into two parts: every expression in the
      // first part contains at least one WindowExpression; expressions in the second part
      // do not have any WindowExpression.
      val (expressionsWithWindowFunctions, regularExpressions) =
        expressions.partition(hasWindowFunction)

      // Then, we need to extract those regular expressions used in the WindowExpression.
      // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
      // we need to make sure that col1 to col5 are all projected from the child of the Window
      // operator.
      val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
      def extractExpr(expr: Expression): Expression = expr match {
        case ne: NamedExpression =>
          // If a named expression is not in regularExpressions, add it to
          // extractedExprBuffer and replace it with an AttributeReference.
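          // e.g. for a hypothetical `sum(b) OVER (PARTITION BY a)`, the partition key `a`
          // is already named, so it is simply added to the buffer and thus projected from
          // the child of the Window operator.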
          val missingExpr =
            AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
          if (missingExpr.nonEmpty) {
            extractedExprBuffer += ne
          }
          // The alias will be cleaned up in the rule CleanupAliases.
          ne
        case e: Expression if e.foldable =>
          e // No need to create an attribute reference if it will be evaluated as a Literal.
        case e: Expression =>
          // For other expressions, we extract it and replace it with an AttributeReference (with
          // an internal column name, e.g. "_w0").
          val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
          extractedExprBuffer += withName
          withName.toAttribute
      }

      // Now, we extract regular expressions from expressionsWithWindowFunctions
      // by using extractExpr.
      val seenWindowAggregates = new ArrayBuffer[AggregateExpression]
      val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
        _.transform {
          // Extracts children expressions of a WindowFunction (input parameters of
          // a WindowFunction).
          case wf: WindowFunction =>
            val newChildren = wf.children.map(extractExpr)
            wf.withNewChildren(newChildren)

          // Extracts expressions from the partition spec and order spec.
          case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
            val newPartitionSpec = partitionSpec.map(extractExpr)
            val newOrderSpec = orderSpec.map { so =>
              val newChild = extractExpr(so.child)
              so.copy(child = newChild)
            }
            wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)

          // Extracts windowed AggregateExpressions.
          case we @ WindowExpression(
              ae @ AggregateExpression(function, _, _, _),
              spec: WindowSpecDefinition) =>
            val newChildren = function.children.map(extractExpr)
            val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
            val newAgg = ae.copy(aggregateFunction = newFunction)
            seenWindowAggregates += newAgg
            WindowExpression(newAgg, spec)

          // Extracts AggregateExpressions. For example, for SUM(x) - Sum(y) OVER (...),
          // we need to extract SUM(x).
          case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
            val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
            extractedExprBuffer += withName
            withName.toAttribute

          // Extracts other attributes.
          case attr: Attribute => extractExpr(attr)

        }.asInstanceOf[NamedExpression]
      }

      (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
    } // end of extract

    /**
     * Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
     */
    private def addWindow(
        expressionsWithWindowFunctions: Seq[NamedExpression],
        child: LogicalPlan): LogicalPlan = {
      // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
      // and put those extracted WindowExpressions into extractedWindowExprBuffer.
      // This step is needed because it is possible that an expression contains multiple
      // WindowExpressions with different Window Specs.
      // After extracting WindowExpressions, we need to construct a project list to generate
      // expressionsWithWindowFunctions based on extractedWindowExprBuffer.
      // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
      // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
      // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
      // Then, the projectList will be [_we0/_we1].
      val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
      val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
        // We need to use transformDown because we want to trigger
        // "case alias @ Alias(window: WindowExpression, _)" first.
        _.transformDown {
          case alias @ Alias(window: WindowExpression, _) =>
            // If a WindowExpression has an assigned alias, just use it.
            extractedWindowExprBuffer += alias
            alias.toAttribute
          case window: WindowExpression =>
            // If there is no alias assigned to the WindowExpression, we create an
            // internal column.
            val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
            extractedWindowExprBuffer += withName
            withName.toAttribute
        }.asInstanceOf[NamedExpression]
      }

      // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs.
      val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
        val distinctWindowSpec = expr.collect {
          case window: WindowExpression => window.windowSpec
        }.distinct

        // We do a final check and see if we only have a single Window Spec defined in an
        // expression.
        if (distinctWindowSpec.isEmpty) {
          failAnalysis(s"$expr does not have any WindowExpression.")
        } else if (distinctWindowSpec.length > 1) {
          // newExpressionsWithWindowFunctions only has expressions with a single
          // WindowExpression. If we reach here, we have a bug.
          failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec). " +
            s"Please file a bug report with this error message, stack trace, and the query.")
        } else {
          val spec = distinctWindowSpec.head
          (spec.partitionSpec, spec.orderSpec)
        }
      }.toSeq

      // Third, we build a chain of Window operators, adding one Window operator for each
      // distinct Window Spec and making it the child of the next Window operator.
      val windowOps =
        groupedWindowExpressions.foldLeft(child) {
          case (last, ((partitionSpec, orderSpec), windowExpressions)) =>
            Window(windowExpressions, partitionSpec, orderSpec, last)
        }

      // Finally, we create a Project on top of windowOps to output
      // newExpressionsWithWindowFunctions.
      Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps)
    } // end of addWindow

    // We have to use transformDown here to make sure the rule for
    // "Aggregate with Having clause" is triggered first.
    def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {

      // Aggregate with Having clause. This rule works with an unresolved Aggregate because
      // a resolved Aggregate will not have Window Functions.
      case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
        if child.resolved &&
           hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
        val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
        // Create an Aggregate operator to evaluate aggregation functions.
        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
        // Add a Filter operator for conditions in the Having clause.
        val withFilter = Filter(condition, withAggregate)
        val withWindow = addWindow(windowExpressions, withFilter)

        // Finally, generate output columns according to the original projectList.
        val finalProjectList = aggregateExprs.map(_.toAttribute)
        Project(finalProjectList, withWindow)

      case p: LogicalPlan if !p.childrenResolved => p

      // Aggregate without Having clause.
      case a @ Aggregate(groupingExprs, aggregateExprs, child)
        if hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
        val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
        // Create an Aggregate operator to evaluate aggregation functions.
        val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
        // Add Window operators.
        val withWindow = addWindow(windowExpressions, withAggregate)

        // Finally, generate output columns according to the original projectList.
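        // e.g. a hypothetical `SELECT k, sum(v), rank() OVER (ORDER BY k) FROM t GROUP BY k`
        // becomes Project <- Window <- Aggregate: the aggregation is evaluated first, and
        // the ranking runs over its output.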
        val finalProjectList = aggregateExprs.map(_.toAttribute)
        Project(finalProjectList, withWindow)

      // We only extract Window Expressions after all expressions of the Project
      // have been resolved.
      case p @ Project(projectList, child)
        if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
        val (windowExpressions, regularExpressions) = extract(projectList)
        // We add a project to get all needed expressions for window expressions from the child
        // of the original Project operator.
        val withProject = Project(regularExpressions, child)
        // Add Window operators.
        val withWindow = addWindow(windowExpressions, withProject)

        // Finally, generate output columns according to the original projectList.
        val finalProjectList = projectList.map(_.toAttribute)
        Project(finalProjectList, withWindow)
    }
  }

  /**
   * Pulls out nondeterministic expressions from a LogicalPlan that is not a Project or a Filter,
   * puts them into an inner Project, and finally projects them away in an outer Project.
   */
  object PullOutNondeterministic extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.resolved => p // Skip unresolved nodes.
      case p: Project => p
      case f: Filter => f

      case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
        val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
        val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
        a.transformExpressions { case e =>
          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
        }.copy(child = newChild)

      // TODO: it's hard to write a general rule to pull out nondeterministic expressions
      // from a LogicalPlan; currently we only do it for a UnaryNode which has the same output
      // schema as its child.
      case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
        val nondeterToAttr = getNondeterToAttr(p.expressions)
        val newPlan = p.transformExpressions { case e =>
          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
        }
        val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
        Project(p.output, newPlan.withNewChildren(newChild :: Nil))
    }

    private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
      exprs.filterNot(_.deterministic).flatMap { expr =>
        val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
        leafNondeterministic.distinct.map { e =>
          val ne = e match {
            case n: NamedExpression => n
            case _ => Alias(e, "_nondeterministic")(isGenerated = true)
          }
          e -> ne
        }
      }.toMap
    }
  }

  /**
   * Correctly handles null primitive inputs for UDFs by adding an extra [[If]] expression to do
   * the null check. When a user defines a UDF with primitive parameters, there is no way to tell
   * if the primitive parameter is null or not, so here we assume the primitive input is
   * null-propagatable and we should return null if the input is null.
   */
  object HandleNullInputsForUDF extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.resolved => p // Skip unresolved nodes.

      case p => p transformExpressionsUp {

        case udf @ ScalaUDF(func, _, inputs, _, _) =>
          val parameterTypes = ScalaReflection.getParameterTypes(func)
          assert(parameterTypes.length == inputs.length)

          val inputsNullCheck = parameterTypes.zip(inputs)
            // TODO: skip null handling for not-nullable primitive inputs after we can completely
            // trust the `nullable` information.
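            // For a hypothetical UDF `f(x: Int)` applied to a nullable column, the rewrite
            // below produces `If(IsNull(x), null, f(x))`, so a null input yields null instead
            // of the garbage value produced by unboxing null into a primitive.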
            // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
            .filter { case (cls, _) => cls.isPrimitive }
            .map { case (_, expr) => IsNull(expr) }
            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
          inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
      }
    }
  }

  /**
   * Checks and adds proper window frames for all window functions.
   */
  object ResolveWindowFrame extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case logical: LogicalPlan => logical transformExpressions {
        case WindowExpression(wf: WindowFunction,
        WindowSpecDefinition(_, _, f: SpecifiedWindowFrame))
          if wf.frame != UnspecifiedFrame && wf.frame != f =>
          failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}")
        case WindowExpression(wf: WindowFunction,
        s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
          if wf.frame != UnspecifiedFrame =>
          WindowExpression(wf, s.copy(frameSpecification = wf.frame))
        case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
          if e.resolved =>
          val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true)
          we.copy(windowSpec = s.copy(frameSpecification = frame))
      }
    }
  }

  /**
   * Checks and adds an order to [[AggregateWindowFunction]]s.
   */
  object ResolveWindowOrder extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case logical: LogicalPlan => logical transformExpressions {
        case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty =>
          failAnalysis(s"Window function $wf requires window to be ordered, please add an " +
            s"ORDER BY clause. For example, SELECT $wf(value_expr) OVER (PARTITION BY " +
            s"window_partition ORDER BY window_ordering) FROM table")
        case WindowExpression(rank: RankLike, spec) if spec.resolved =>
          val order = spec.orderSpec.map(_.child)
          WindowExpression(rank.withOrder(order), spec)
      }
    }
  }

  /**
   * Removes natural or using joins by calculating the output columns based on the output of the
   * two sides, then applies a Project on top of a normal Join to eliminate the natural or using
   * join.
   */
  object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
          if left.resolved && right.resolved && j.duplicateResolved =>
        commonNaturalJoinProcessing(left, right, joinType, usingCols, None)
      case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
        // find common column names from both sides
        val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
    }
  }

  private def commonNaturalJoinProcessing(
      left: LogicalPlan,
      right: LogicalPlan,
      joinType: JoinType,
      joinNames: Seq[String],
      condition: Option[Expression]) = {
    val leftKeys = joinNames.map { keyName =>
      left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
        throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " +
          s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]")
      }
    }
    val rightKeys = joinNames.map { keyName =>
      right.output.find(attr => resolver(attr.name, keyName)).getOrElse {
        throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " +
          s"side of the join. The right-side columns: " +
          s"[${right.output.map(_.name).mkString(", ")}]")
      }
    }
    val joinPairs = leftKeys.zip(rightKeys)

    val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)

    // columns not in joinPairs
    val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
    val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))

    // the output list looks like: join keys, columns from left, columns from right
    val projectList = joinType match {
      case LeftOuter =>
        leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
      case LeftExistence(_) =>
        leftKeys ++ lUniqueOutput
      case RightOuter =>
        rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
      case FullOuter =>
        // In a full outer join, a join column should be non-null if at least one side's
        // column is non-null, so coalesce the two sides.
        val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
        joinedCols ++
          lUniqueOutput.map(_.withNullability(true)) ++
          rUniqueOutput.map(_.withNullability(true))
      case _: InnerLike =>
        leftKeys ++ lUniqueOutput ++ rUniqueOutput
      case _ =>
        sys.error("Unsupported natural join type " + joinType)
    }
    // use Project to trim unnecessary fields
    Project(projectList, Join(left, right, joinType, newCondition))
  }

  /**
   * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
   * to the given input attributes.
   */
  object ResolveDeserializer extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.childrenResolved => p
      case p if p.resolved => p

      case p => p transformExpressions {
        case UnresolvedDeserializer(deserializer, inputAttributes) =>
          val inputs = if (inputAttributes.isEmpty) {
            p.children.flatMap(_.output)
          } else {
            inputAttributes
          }

          validateTopLevelTupleFields(deserializer, inputs)
          val resolved = resolveExpression(
            deserializer, LocalRelation(inputs), throws = true)
          val result = resolved transformDown {
            case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
              inputData.dataType match {
                case ArrayType(et, cn) =>
                  val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
                    case UnresolvedExtractValue(child, fieldName) if child.resolved =>
                      ExtractValue(child, fieldName, resolver)
                  }
                  expr
                case other =>
                  throw new AnalysisException("need an array field but got " + other.simpleString)
              }
          }
          validateNestedTupleFields(result)
          result
      }
    }

    private def fail(schema: StructType, maxOrdinal: Int): Unit = {
      throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " +
        "but failed as the number of fields does not line up.")
    }

    /**
     * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding
     * column by position. However, the actual number of columns may be different from the number
     * of Tuple fields. This method is used to check the number of columns and fields, and throw
     * an exception if they do not match.
     */
    private def validateTopLevelTupleFields(
        deserializer: Expression,
        inputs: Seq[Attribute]): Unit = {
      val ordinals = deserializer.collect {
        case GetColumnByOrdinal(ordinal, _) => ordinal
      }.distinct.sorted

      if (ordinals.nonEmpty && ordinals != inputs.indices) {
        fail(inputs.toStructType, ordinals.last)
      }
    }

    /**
     * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct
     * field by position. However, the actual number of struct fields may be different from the
     * number of nested Tuple fields.
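     * For example (hypothetical), mapping a struct with two fields to a `Tuple3` references
     * ordinals 0, 1 and 2 while the schema only has indices 0 and 1, so validation fails.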
     * This method is used to check the number of struct fields and nested Tuple fields, and
     * throw an exception if they do not match.
     */
    private def validateNestedTupleFields(deserializer: Expression): Unit = {
      val structChildToOrdinals = deserializer
        // There are 2 kinds of `GetStructField`:
        // 1. resolved from `UnresolvedExtractValue`, which will have a `name` property.
        // 2. created when we build the deserializer expression for a nested tuple, with no
        //    `name` property.
        // Here we want to validate the ordinals of the nested tuple, so we should only catch
        // `GetStructField` without the name property.
        .collect { case g: GetStructField if g.name.isEmpty => g }
        .groupBy(_.child)
        .mapValues(_.map(_.ordinal).distinct.sorted)

      structChildToOrdinals.foreach { case (expr, ordinals) =>
        val schema = expr.dataType.asInstanceOf[StructType]
        if (ordinals != schema.indices) {
          fail(schema, ordinals.last)
        }
      }
    }
  }

  /**
   * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
   * constructed is an inner class.
   */
  object ResolveNewInstance extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.childrenResolved => p
      case p if p.resolved => p

      case p => p transformExpressions {
        case n: NewInstance if n.childrenResolved && !n.resolved =>
          val outer = OuterScopes.getOuterScope(n.cls)
          if (outer == null) {
            throw new AnalysisException(
              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
                "access to the scope that this class was defined in.\n" +
                "Try moving this class out of its parent class.")
          }
          n.copy(outerPointer = Some(outer))
      }
    }
  }

  /**
   * Replaces the [[UpCast]] expression with [[Cast]], and throws exceptions if the cast may
   * truncate.
   */
  object ResolveUpCast extends Rule[LogicalPlan] {
    private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
      val fromStr = from match {
        case l: LambdaVariable => "array element"
        case e => e.sql
      }
      throw new AnalysisException(s"Cannot up cast $fromStr from " +
        s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
        "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
        "You can either add an explicit cast to the input data or choose a higher precision " +
        "type of the field in the target object")
    }

    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
      case p if !p.childrenResolved => p
      case p if p.resolved => p

      case p => p transformExpressions {
        case u @ UpCast(child, _, _) if !child.resolved => u

        case UpCast(child, dataType, walkedTypePath)
          if Cast.mayTruncate(child.dataType, dataType) =>
          fail(child, dataType, walkedTypePath)

        case UpCast(child, dataType, walkedTypePath) => Cast(child, dataType.asNullable)
      }
    }
  }

  /**
   * Replaces a [[TimeZoneAwareExpression]] without a timezone id with its copy using the
   * session-local time zone.
   */
  object ResolveTimeZone extends Rule[LogicalPlan] {

    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
        e.withTimeZone(conf.sessionLocalTimeZone)
      // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
      // the types between the value expression and the list query expression of an IN expression.
      // We need to run the subquery plan through ResolveTimeZone again to set up timezone
      // information for time zone aware expressions.
      case e: ListQuery => e.withNewPlan(apply(e.plan))
    }
  }
}

/**
 * Removes [[SubqueryAlias]] operators from the plan.
 * Subqueries are only required to provide
 * scoping information for attributes and can be removed once analysis is complete.
 */
object EliminateSubqueryAliases extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case SubqueryAlias(_, child) => child
  }
}

/**
 * Removes [[Union]] operators from the plan if it has just one child.
 */
object EliminateUnions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Union(children) if children.size == 1 => children.head
  }
}

/**
 * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top-level
 * expression in Project (project list), Aggregate (aggregate expressions), or
 * Window (window expressions).
 */
object CleanupAliases extends Rule[LogicalPlan] {
  private def trimAliases(e: Expression): Expression = {
    e.transformDown {
      case Alias(child, _) => child
    }
  }

  def trimNonTopLevelAliases(e: Expression): Expression = e match {
    case a: Alias =>
      a.withNewChildren(trimAliases(a.child) :: Nil)
    case other => trimAliases(other)
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case Project(projectList, child) =>
      val cleanedProjectList =
        projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
      Project(cleanedProjectList, child)

    case Aggregate(grouping, aggs, child) =>
      val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
      Aggregate(grouping.map(trimAliases), cleanedAggs, child)

    case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
      val cleanedWindowExprs =
        windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
      Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)

    // Operators that operate on objects should only have expressions from encoders, which should
    // never have extra aliases.
    case o: ObjectConsumer => o
    case o: ObjectProducer => o
    case a: AppendColumns => a

    case other =>
      other transformExpressionsDown {
        case Alias(child, _) => child
      }
  }
}

/**
 * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial
 * to figure out how many windows a time column can map to, we over-estimate the number of windows
 * and filter out the rows where the time column is not inside the time window.
 */
object TimeWindowing extends Rule[LogicalPlan] {
  import org.apache.spark.sql.catalyst.dsl.expressions._

  private final val WINDOW_START = "start"
  private final val WINDOW_END = "end"

  /**
   * Generates the logical plan for generating window ranges on a timestamp column. Without
   * knowing what the timestamp value is, it's non-trivial to figure out deterministically how
   * many window ranges a timestamp will map to given all possible combinations of a window
   * duration, slide duration and start time (offset). Therefore, we over-estimate the number of
   * windows there may be, and filter out the invalid ones afterwards. We use the last Project
   * operator to group the window columns into a struct so they can be accessed as `window.start`
   * and `window.end`.
   *
   * The windows are calculated as below:
   * maxNumOverlapping <- ceil(windowDuration / slideDuration)
   * for (i <- 0 until maxNumOverlapping)
   *   windowId <- ceil((timestamp - startTime) / slideDuration)
   *   windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration +
   *     startTime
   *   windowEnd <- windowStart + windowDuration
   *   return windowStart, windowEnd
   *
   * This behaves as follows for the given parameters for the time 12:05.
 * The valid windows are
 * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered out
 * using the Filter operator.
 *
 *     window: 12m, slide: 5m, start: 0m    ::   window: 12m, slide: 5m, start: 2m
 *         11:55 - 12:07 +                           11:52 - 12:04 x
 *         12:00 - 12:12 +                           11:57 - 12:09 +
 *         12:05 - 12:17 +                           12:02 - 12:14 +
 *
 * @param plan The logical plan
 * @return the logical plan that will generate the time windows using the Expand operator, with
 *         the Filter operator for correctness and Project for usability.
 */
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case p: LogicalPlan if p.children.size == 1 =>
      val child = p.children.head
      val windowExpressions =
        p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList

      // Only support a single window expression for now
      if (windowExpressions.size == 1 &&
          windowExpressions.head.timeColumn.resolved &&
          windowExpressions.head.checkInputDataTypes().isSuccess) {
        val window = windowExpressions.head

        val metadata = window.timeColumn match {
          case a: Attribute => a.metadata
          case _ => Metadata.empty
        }
        val windowAttr =
          AttributeReference("window", window.dataType, metadata = metadata)()

        val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
        val windows = Seq.tabulate(maxNumOverlapping + 1) { i =>
          val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) /
            window.slideDuration)
          val windowStart = (windowId + i - maxNumOverlapping) *
            window.slideDuration + window.startTime
          val windowEnd = windowStart + window.windowDuration

          CreateNamedStruct(
            Literal(WINDOW_START) :: windowStart ::
            Literal(WINDOW_END) :: windowEnd :: Nil)
        }

        val projections = windows.map(_ +: p.children.head.output)

        val filterExpr =
          window.timeColumn >= windowAttr.getField(WINDOW_START) &&
          window.timeColumn < windowAttr.getField(WINDOW_END)

        val expandedPlan =
          Filter(filterExpr,
            Expand(projections, windowAttr +: child.output, child))

        val substitutedPlan = p transformExpressions {
          case t: TimeWindow => windowAttr
        }

        substitutedPlan.withNewChildren(expandedPlan :: Nil)
      } else if (windowExpressions.size > 1) {
        p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
          "of rows, therefore they are currently not supported.")
      } else {
        p // Return unchanged. The Analyzer will throw an exception later.
      }
  }
}

/**
 * Resolves a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
 */
object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    case e: CreateNamedStruct if !e.resolved =>
      val children = e.children.grouped(2).flatMap {
        case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
          Seq(Literal(e.name), e)
        case kv => kv
      }
      CreateNamedStruct(children.toList)
  }
}

/**
 * The aggregate expressions from a subquery referencing the outer query block are pushed
 * down to the outer query block for evaluation. The rule below updates such outer references
 * to AttributeReferences that refer to the attributes produced by the parent/outer query block.
 *
 * For example (SQL):
 * {{{
 *   SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b))
 * }}}
 * Plan before the rule.
 *    Project [a#226]
 *    +- Filter exists#245 [min(b#227)#249]
 *       :  +- Project [1 AS 1#247]
 *       :     +- Filter (d#238 < min(outer(b#227)))       <-----
 *       :        +- SubqueryAlias r
 *       :           +- Project [_1#234 AS c#237, _2#235 AS d#238]
 *       :              +- LocalRelation [_1#234, _2#235]
 *       +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
 *          +- SubqueryAlias l
 *             +- Project [_1#223 AS a#226, _2#224 AS b#227]
 *                +- LocalRelation [_1#223, _2#224]
 * Plan after the rule.
 *    Project [a#226]
 *    +- Filter exists#245 [min(b#227)#249]
 *       :  +- Project [1 AS 1#247]
 *       :     +- Filter (d#238 < outer(min(b#227)#249))   <-----
 *       :        +- SubqueryAlias r
 *       :           +- Project [_1#234 AS c#237, _2#235 AS d#238]
 *       :              +- LocalRelation [_1#234, _2#235]
 *       +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
 *          +- SubqueryAlias l
 *             +- Project [_1#223 AS a#226, _2#224 AS b#227]
 *                +- LocalRelation [_1#223, _2#224]
 */
object UpdateOuterReferences extends Rule[LogicalPlan] {
  private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child }

  private def updateOuterReferenceInSubquery(
      plan: LogicalPlan,
      refExprs: Seq[Expression]): LogicalPlan = {
    plan transformAllExpressions { case e =>
      val outerAlias =
        refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e)))
      outerAlias match {
        case Some(a: Alias) => OuterReference(a.toAttribute)
        case _ => e
      }
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    plan transform {
      case f @ Filter(_, a: Aggregate) if f.resolved =>
        f transformExpressions {
          case s: SubqueryExpression if s.children.nonEmpty =>
            // Collect the aliases from the output of the aggregate.
            val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
            // Update the subquery plan so its OuterReferences point to the outer query plan.
            s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases))
        }
    }
  }
}