-rw-r--r--  python/pyspark/sql/dataframe.py | 12
-rw-r--r--  python/pyspark/sql/functions.py | 20
-rw-r--r--  python/pyspark/sql/tests.py | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 117
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 60
10 files changed, 223 insertions, 51 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 82cb1c2fdb..2ed95ac8e2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1511,13 +1511,19 @@ class Column(object):
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
- def alias(self, alias):
- """Return a alias for this column
+ def alias(self, *alias):
+ """Returns this column aliased with a new name or names (in the case of expressions that
+ return more than one column, such as explode).
>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
"""
- return Column(getattr(self._jc, "as")(alias))
+
+ if len(alias) == 1:
+ return Column(getattr(self._jc, "as")(alias[0]))
+ else:
+ sc = SparkContext._active_spark_context
+ return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
@ignore_unicode_prefix
def cast(self, dataType):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d91265ee0b..6cd6974b0e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -169,6 +169,26 @@ def approxCountDistinct(col, rsd=None):
return Column(jc)
+def explode(col):
+ """Returns a new row for each element in the given array or map.
+
+ >>> from pyspark.sql import Row
+ >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+ >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
+ [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
+
+ >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | a| b|
+ +---+-----+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.explode(_to_java_column(col))
+ return Column(jc)
+
+
def coalesce(*cols):
"""Returns the first column that is not null.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1922d03af6..d37c5dbed7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ def test_explode(self):
+ from pyspark.sql.functions import explode
+ d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+ rdd = self.sc.parallelize(d)
+ data = self.sqlCtx.createDataFrame(rdd)
+
+ result = data.select(explode(data.intlist).alias("a")).select("a").collect()
+ self.assertEqual(result[0][0], 1)
+ self.assertEqual(result[1][0], 2)
+ self.assertEqual(result[2][0], 3)
+
+ result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
+ self.assertEqual(result[0][0], "a")
+ self.assertEqual(result[0][1], "b")
+
def test_udf_with_callable(self):
d = [Row(number=i, squared=i**2) for i in range(10)]
rdd = self.sc.parallelize(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4baeeb5b58..0b6e1d44b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,6 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
- ImplicitGenerate ::
ResolveFunctions ::
ExtractWindowExpressions ::
GlobalAggregates ::
@@ -323,6 +322,11 @@ class Analyzer(
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
+ case oldVersion: Generate
+ if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+ val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+ (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
case oldVersion @ Window(_, windowExpressions, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
@@ -521,66 +525,89 @@ class Analyzer(
}
/**
- * When a SELECT clause has only a single expression and that expression is a
- * [[catalyst.expressions.Generator Generator]] we convert the
- * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]].
+ * Rewrites table generating expressions that need one or more of the following in order
+ * to be resolved:
+ * - concrete attribute references for their output.
+ * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]].
+ *
+ * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions
+ * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an
+ * [[AnalysisException]] is thrown.
*/
- object ImplicitGenerate extends Rule[LogicalPlan] {
+ object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Project(Seq(Alias(g: Generator, name)), child) =>
- Generate(g, join = false, outer = false,
- qualifier = None, UnresolvedAttribute(name) :: Nil, child)
- case Project(Seq(MultiAlias(g: Generator, names)), child) =>
- Generate(g, join = false, outer = false,
- qualifier = None, names.map(UnresolvedAttribute(_)), child)
+ case p: Generate if !p.child.resolved || !p.generator.resolved => p
+ case g: Generate if !g.resolved =>
+ g.copy(
+ generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
+
+ case p @ Project(projectList, child) =>
+ // Holds the resolved generator, if one exists in the project list.
+ var resolvedGenerator: Generate = null
+
+ val newProjectList = projectList.flatMap {
+ case AliasedGenerator(generator, names) if generator.childrenResolved =>
+ if (resolvedGenerator != null) {
+ failAnalysis(
+ s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " +
+ s"and ${generator.nodeName} found.")
+ }
+
+ resolvedGenerator =
+ Generate(
+ generator,
+ join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+ outer = false,
+ qualifier = None,
+ generatorOutput = makeGeneratorOutput(generator, names),
+ child)
+
+ resolvedGenerator.generatorOutput
+ case other => other :: Nil
+ }
+
+ if (resolvedGenerator != null) {
+ Project(newProjectList, resolvedGenerator)
+ } else {
+ p
+ }
}
- }
- /**
- * Resolve the Generate, if the output names specified, we will take them, otherwise
- * we will try to provide the default names, which follow the same rule with Hive.
- */
- object ResolveGenerate extends Rule[LogicalPlan] {
- // Construct the output attributes for the generator,
- // The output attribute names can be either specified or
- // auto generated.
+ /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
+ private object AliasedGenerator {
+ def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
+ case Alias(g: Generator, name) => Some((g, name :: Nil))
+ case MultiAlias(g: Generator, names) => Some((g, names))
+ case _ => None
+ }
+ }
+
+ /**
+ * Construct the output attributes for a [[Generator]], given a list of names. If the list of
+ * names is empty, names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults.
+ */
private def makeGeneratorOutput(
generator: Generator,
- generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+ names: Seq[String]): Seq[Attribute] = {
val elementTypes = generator.elementTypes
- if (generatorOutput.length == elementTypes.length) {
- generatorOutput.zip(elementTypes).map {
- case (a, (t, nullable)) if !a.resolved =>
- AttributeReference(a.name, t, nullable)()
- case (a, _) => a
+ if (names.length == elementTypes.length) {
+ names.zip(elementTypes).map {
+ case (name, (t, nullable)) =>
+ AttributeReference(name, t, nullable)()
}
- } else if (generatorOutput.length == 0) {
+ } else if (names.isEmpty) {
elementTypes.zipWithIndex.map {
// keep the default column names as Hive does _c0, _c1, _cN
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
}
} else {
- throw new AnalysisException(
- s"""
- |The number of aliases supplied in the AS clause does not match
- |the number of columns output by the UDTF expected
- |${elementTypes.size} aliases but got ${generatorOutput.size}
- """.stripMargin)
+ failAnalysis(
+ "The number of aliases supplied in the AS clause does not match the number of columns " +
+ s"output by the UDTF expected ${elementTypes.size} aliases but got " +
+ s"${names.mkString(",")} ")
}
}
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p: Generate if !p.child.resolved || !p.generator.resolved => p
- case p: Generate if p.resolved == false =>
- // if the generator output names are not specified, we will use the default ones.
- Generate(
- p.generator,
- join = p.join,
- outer = p.outer,
- p.qualifier,
- makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
- }
}
/**
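The new ResolveGenerate rule above settles output names in three ways: user-supplied aliases are used when their count matches the generator's element types, Hive-style ordinal names (_c0, _c1, ...) are generated when no aliases are given, and any other mismatch fails analysis. A minimal, self-contained Scala sketch of that decision, using stand-in types rather than Catalyst's:

object GeneratorNamingSketch {
  // Stand-in for Catalyst's (DataType, nullable) element descriptions.
  final case class ElementType(typeName: String, nullable: Boolean)

  // Mirrors the three branches of makeGeneratorOutput above.
  def outputNames(elementTypes: Seq[ElementType], names: Seq[String]): Seq[String] =
    if (names.length == elementTypes.length) names                  // user-supplied aliases win
    else if (names.isEmpty) elementTypes.indices.map(i => s"_c$i")  // Hive-style defaults
    else sys.error(s"expected ${elementTypes.length} aliases but got ${names.length}")

  def main(args: Array[String]): Unit = {
    val mapElements = Seq(ElementType("string", nullable = false), ElementType("string", nullable = true))
    println(outputNames(mapElements, Nil))                  // prints the defaults _c0, _c1
    println(outputNames(mapElements, Seq("key", "value")))  // prints key, value
  }
}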
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0f349f9d11..01f4b6e9bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -59,6 +59,9 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {
+ /** The set of all attributes produced by this node. */
+ def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
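The generatedSet added above feeds the new de-duplication case in Analyzer.scala: when the same Generate appears on both sides of a self join, its generated attributes collide, and one side is rebuilt with fresh expression ids via newInstance(). A toy sketch of that idea, with stand-in classes (not Catalyst's) in which the id alone disambiguates attributes:

object SelfJoinDedupSketch {
  final case class Attr(name: String, id: Long)

  private var nextId = 1L
  def newInstance(a: Attr): Attr = { nextId += 1; a.copy(id = nextId) }

  // If the right side produces attributes that already appear on the left,
  // give the right side fresh ids so references remain unambiguous.
  def dedupRight(left: Set[Attr], right: Set[Attr]): Set[Attr] =
    if (left.intersect(right).nonEmpty) right.map(newInstance) else right

  def main(args: Array[String]): Unit = {
    val generated = Set(Attr("i", id = 1))     // e.g. the output of explode('intList).as('i)
    println(dedupRight(generated, generated))  // Set(Attr(i,2)): the right side gets a fresh id
  }
}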
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6f2f35564d..e1d6ac462f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
StructField("cField", StringType) :: Nil
))())
+ val listRelation = LocalRelation(
+ AttributeReference("list", ArrayType(IntegerType))())
+
before {
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -159,11 +162,16 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
}
}
- errorMessages.foreach(m => assert(error.getMessage contains m))
+ errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
}
}
errorTest(
+ "too many generators",
+ listRelation.select(Explode('list).as('a), Explode('list).as('b)),
+ "only one generator" :: "explode" :: Nil)
+
+ errorTest(
"unresolved attributes",
testRelation.select('abcd),
"cannot resolve" :: "abcd" :: Nil)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bf1320ccb..dc0aeea7c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,12 +18,13 @@
package org.apache.spark.sql
import scala.language.implicitConversions
+import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
import org.apache.spark.sql.types._
@@ -728,6 +729,30 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def as(alias: String): Column = Alias(expr, alias)()
/**
+ * (Scala-specific) Assigns the given aliases to the results of a table generating function.
+ * {{{
+ * // Names the two columns produced by exploding a map column "key" and "value".
+ * df.select(explode($"myMap").as("key" :: "value" :: Nil))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases)
+
+ /**
+ * Assigns the given aliases to the results of a table generating function.
+ * {{{
+ * // Names the two columns produced by exploding a map column "key" and "value".
+ * df.select(explode($"myMap").as(Array("key", "value")))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def as(aliases: Array[String]): Column = MultiAlias(expr, aliases)
+
+ /**
* Gives the column an alias.
* {{{
* // Renames colA to colB in select output.
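Together with the explode function added to functions.scala below, these overloads let callers name every column a generator produces. A hedged usage sketch, assuming a local-mode Spark 1.4 build; the object and app names are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.explode

object MultiAliasExplodeDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("explode-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = Seq((1, Map("a" -> "b"))).toDF("id", "myMap")

    // Scala-specific Seq[String] overload.
    df.select(explode($"myMap").as(Seq("key", "value"))).show()

    // Array[String] overload (also callable from Java).
    df.select(explode($"myMap").as(Array("key", "value"))).show()

    sc.stop()
  }
}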
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4fd5105c27..2e20c3d3f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,7 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -593,6 +593,9 @@ class DataFrame private[sql](
def select(cols: Column*): DataFrame = {
val namedExpressions = cols.map {
case Column(expr: NamedExpression) => expr
+ // Leave an unaliased explode with an empty list of names since the analyzer will generate
+ // the correct defaults after the nested expression's type has been resolved.
+ case Column(explode: Explode) => MultiAlias(explode, Nil)
case Column(expr: Expression) => Alias(expr, expr.prettyString)()
}
// When user continuously call `select`, speed up analysis by collapsing `Project`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4404ad8ad6..6640631cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -364,6 +364,11 @@ object functions {
def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
/**
+ * Creates a new row for each element in the given array or map column.
+ *
+ * @group normal_funcs
+ * @since 1.4.0
+ */
+ def explode(e: Column): Column = Explode(e.expr)
+
+ /**
* Converts a string exprsesion to lower case.
*
* @group normal_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 269e185543..9bdf201b3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -27,6 +27,66 @@ import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._
+ test("single explode") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(explode('intList)),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+ }
+
+ test("explode and other columns") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select($"a", explode('intList)),
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(1, 3) :: Nil)
+
+ checkAnswer(
+ df.select($"*", explode('intList)),
+ Row(1, Seq(1,2,3), 1) ::
+ Row(1, Seq(1,2,3), 2) ::
+ Row(1, Seq(1,2,3), 3) :: Nil)
+ }
+
+ test("aliased explode") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select('int),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select(sum('int)),
+ Row(6) :: Nil)
+ }
+
+ test("explode on map") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map)),
+ Row("a", "b"))
+ }
+
+ test("explode on map with aliases") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+ Row("a", "b"))
+ }
+
+ test("self join explode") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+ val exploded = df.select(explode('intList).as('i))
+
+ checkAnswer(
+ exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+ Row(3) :: Nil)
+ }
+
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))