aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-05-14 19:49:44 -0700
committerMichael Armbrust <michael@databricks.com>2015-05-14 19:49:44 -0700
commit6d0633e3ec9518278fcc7eba58549d4ad3d5813f (patch)
tree4bd5a50dcd8af0b37c44180f607dcd8a8e5b26da /sql/catalyst
parent48fc38f5844f6c12bf440f2990b6d7f1630fafac (diff)
downloadspark-6d0633e3ec9518278fcc7eba58549d4ad3d5813f.tar.gz
spark-6d0633e3ec9518278fcc7eba58549d4ad3d5813f.tar.bz2
spark-6d0633e3ec9518278fcc7eba58549d4ad3d5813f.zip
[SPARK-7548] [SQL] Add explode function for DataFrames
Add an `explode` function for dataframes and modify the analyzer so that single table generating functions can be present in a select clause along with other expressions. There are currently the following restrictions: - only top level TGFs are allowed (i.e. no `select(explode('list) + 1)`) - only one may be present in a single select to avoid potentially confusing implicit Cartesian products. TODO: - [ ] Python Author: Michael Armbrust <michael@databricks.com> Closes #6107 from marmbrus/explodeFunction and squashes the following commits: 7ee2c87 [Michael Armbrust] whitespace 6f80ba3 [Michael Armbrust] Update dataframe.py c176c89 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction 81b5da3 [Michael Armbrust] style d3faa05 [Michael Armbrust] fix self join case f9e1e3e [Michael Armbrust] fix python, add since 4f0d0a9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction e710fe4 [Michael Armbrust] add java and python 52ca0dc [Michael Armbrust] [SPARK-7548][SQL] Add explode function for dataframes.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala117
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala10
3 files changed, 84 insertions, 46 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4baeeb5b58..0b6e1d44b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,6 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
- ImplicitGenerate ::
ResolveFunctions ::
ExtractWindowExpressions ::
GlobalAggregates ::
@@ -323,6 +322,11 @@ class Analyzer(
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
+ case oldVersion: Generate
+ if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+ val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+ (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
case oldVersion @ Window(_, windowExpressions, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
@@ -521,66 +525,89 @@ class Analyzer(
}
/**
- * When a SELECT clause has only a single expression and that expression is a
- * [[catalyst.expressions.Generator Generator]] we convert the
- * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]].
+ * Rewrites table generating expressions that either need one or more of the following in order
+ * to be resolved:
+ * - concrete attribute references for their output.
+ * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]).
+ *
+ * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions
+ * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an
+ * [[AnalysisException]] is throw.
*/
- object ImplicitGenerate extends Rule[LogicalPlan] {
+ object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Project(Seq(Alias(g: Generator, name)), child) =>
- Generate(g, join = false, outer = false,
- qualifier = None, UnresolvedAttribute(name) :: Nil, child)
- case Project(Seq(MultiAlias(g: Generator, names)), child) =>
- Generate(g, join = false, outer = false,
- qualifier = None, names.map(UnresolvedAttribute(_)), child)
+ case p: Generate if !p.child.resolved || !p.generator.resolved => p
+ case g: Generate if g.resolved == false =>
+ g.copy(
+ generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
+
+ case p @ Project(projectList, child) =>
+ // Holds the resolved generator, if one exists in the project list.
+ var resolvedGenerator: Generate = null
+
+ val newProjectList = projectList.flatMap {
+ case AliasedGenerator(generator, names) if generator.childrenResolved =>
+ if (resolvedGenerator != null) {
+ failAnalysis(
+ s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " +
+ s"and ${generator.nodeName} found.")
+ }
+
+ resolvedGenerator =
+ Generate(
+ generator,
+ join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+ outer = false,
+ qualifier = None,
+ generatorOutput = makeGeneratorOutput(generator, names),
+ child)
+
+ resolvedGenerator.generatorOutput
+ case other => other :: Nil
+ }
+
+ if (resolvedGenerator != null) {
+ Project(newProjectList, resolvedGenerator)
+ } else {
+ p
+ }
}
- }
- /**
- * Resolve the Generate, if the output names specified, we will take them, otherwise
- * we will try to provide the default names, which follow the same rule with Hive.
- */
- object ResolveGenerate extends Rule[LogicalPlan] {
- // Construct the output attributes for the generator,
- // The output attribute names can be either specified or
- // auto generated.
+ /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
+ private object AliasedGenerator {
+ def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
+ case Alias(g: Generator, name) => Some((g, name :: Nil))
+ case MultiAlias(g: Generator, names) => Some(g, names)
+ case _ => None
+ }
+ }
+
+ /**
+ * Construct the output attributes for a [[Generator]], given a list of names. If the list of
+ * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults.
+ */
private def makeGeneratorOutput(
generator: Generator,
- generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+ names: Seq[String]): Seq[Attribute] = {
val elementTypes = generator.elementTypes
- if (generatorOutput.length == elementTypes.length) {
- generatorOutput.zip(elementTypes).map {
- case (a, (t, nullable)) if !a.resolved =>
- AttributeReference(a.name, t, nullable)()
- case (a, _) => a
+ if (names.length == elementTypes.length) {
+ names.zip(elementTypes).map {
+ case (name, (t, nullable)) =>
+ AttributeReference(name, t, nullable)()
}
- } else if (generatorOutput.length == 0) {
+ } else if (names.isEmpty) {
elementTypes.zipWithIndex.map {
// keep the default column names as Hive does _c0, _c1, _cN
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
}
} else {
- throw new AnalysisException(
- s"""
- |The number of aliases supplied in the AS clause does not match
- |the number of columns output by the UDTF expected
- |${elementTypes.size} aliases but got ${generatorOutput.size}
- """.stripMargin)
+ failAnalysis(
+ "The number of aliases supplied in the AS clause does not match the number of columns " +
+ s"output by the UDTF expected ${elementTypes.size} aliases but got " +
+ s"${names.mkString(",")} ")
}
}
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p: Generate if !p.child.resolved || !p.generator.resolved => p
- case p: Generate if p.resolved == false =>
- // if the generator output names are not specified, we will use the default ones.
- Generate(
- p.generator,
- join = p.join,
- outer = p.outer,
- p.qualifier,
- makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
- }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0f349f9d11..01f4b6e9bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -59,6 +59,9 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {
+ /** The set of all attributes produced by this node. */
+ def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6f2f35564d..e1d6ac462f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
StructField("cField", StringType) :: Nil
))())
+ val listRelation = LocalRelation(
+ AttributeReference("list", ArrayType(IntegerType))())
+
before {
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -159,11 +162,16 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
}
}
- errorMessages.foreach(m => assert(error.getMessage contains m))
+ errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
}
}
errorTest(
+ "too many generators",
+ listRelation.select(Explode('list).as('a), Explode('list).as('b)),
+ "only one generator" :: "explode" :: Nil)
+
+ errorTest(
"unresolved attributes",
testRelation.select('abcd),
"cannot resolve" :: "abcd" :: Nil)