diff options
10 files changed, 223 insertions, 51 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 82cb1c2fdb..2ed95ac8e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1511,13 +1511,19 @@ class Column(object): isNull = _unary_op("isNull", "True if the current expression is null.") isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - def alias(self, alias): - """Return a alias for this column + def alias(self, *alias): + """Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). >>> df.select(df.age.alias("age2")).collect() [Row(age2=2), Row(age2=5)] """ - return Column(getattr(self._jc, "as")(alias)) + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) @ignore_unicode_prefix def cast(self, dataType): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d91265ee0b..6cd6974b0e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -169,6 +169,26 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1922d03af6..d37c5dbed7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_explode(self): + from pyspark.sql.functions import explode + d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4baeeb5b58..0b6e1d44b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -73,7 +73,6 @@ class Analyzer( ResolveGroupingAnalytics :: ResolveSortReferences :: ResolveGenerate :: - ImplicitGenerate :: ResolveFunctions :: ExtractWindowExpressions :: GlobalAggregates :: @@ -323,6 +322,11 @@ class Analyzer( if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + case oldVersion @ Window(_, windowExpressions, _, child) if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => @@ -521,66 +525,89 @@ class Analyzer( } /** - * When a SELECT clause has only a single expression and that expression is a - * [[catalyst.expressions.Generator Generator]] we convert the - * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]]. + * Rewrites table generating expressions that either need one or more of the following in order + * to be resolved: + * - concrete attribute references for their output. + * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * + * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions + * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an + * [[AnalysisException]] is throw. */ - object ImplicitGenerate extends Rule[LogicalPlan] { + object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, name)), child) => - Generate(g, join = false, outer = false, - qualifier = None, UnresolvedAttribute(name) :: Nil, child) - case Project(Seq(MultiAlias(g: Generator, names)), child) => - Generate(g, join = false, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), child) + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case g: Generate if g.resolved == false => + g.copy( + generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + + case p @ Project(projectList, child) => + // Holds the resolved generator, if one exists in the project list. + var resolvedGenerator: Generate = null + + val newProjectList = projectList.flatMap { + case AliasedGenerator(generator, names) if generator.childrenResolved => + if (resolvedGenerator != null) { + failAnalysis( + s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " + + s"and ${generator.nodeName} found.") + } + + resolvedGenerator = + Generate( + generator, + join = projectList.size > 1, // Only join if there are other expressions in SELECT. + outer = false, + qualifier = None, + generatorOutput = makeGeneratorOutput(generator, names), + child) + + resolvedGenerator.generatorOutput + case other => other :: Nil + } + + if (resolvedGenerator != null) { + Project(newProjectList, resolvedGenerator) + } else { + p + } } - } - /** - * Resolve the Generate, if the output names specified, we will take them, otherwise - * we will try to provide the default names, which follow the same rule with Hive. - */ - object ResolveGenerate extends Rule[LogicalPlan] { - // Construct the output attributes for the generator, - // The output attribute names can be either specified or - // auto generated. + /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ + private object AliasedGenerator { + def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { + case Alias(g: Generator, name) => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) => Some(g, names) + case _ => None + } + } + + /** + * Construct the output attributes for a [[Generator]], given a list of names. If the list of + * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults. + */ private def makeGeneratorOutput( generator: Generator, - generatorOutput: Seq[Attribute]): Seq[Attribute] = { + names: Seq[String]): Seq[Attribute] = { val elementTypes = generator.elementTypes - if (generatorOutput.length == elementTypes.length) { - generatorOutput.zip(elementTypes).map { - case (a, (t, nullable)) if !a.resolved => - AttributeReference(a.name, t, nullable)() - case (a, _) => a + if (names.length == elementTypes.length) { + names.zip(elementTypes).map { + case (name, (t, nullable)) => + AttributeReference(name, t, nullable)() } - } else if (generatorOutput.length == 0) { + } else if (names.isEmpty) { elementTypes.zipWithIndex.map { // keep the default column names as Hive does _c0, _c1, _cN case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() } } else { - throw new AnalysisException( - s""" - |The number of aliases supplied in the AS clause does not match - |the number of columns output by the UDTF expected - |${elementTypes.size} aliases but got ${generatorOutput.size} - """.stripMargin) + failAnalysis( + "The number of aliases supplied in the AS clause does not match the number of columns " + + s"output by the UDTF expected ${elementTypes.size} aliases but got " + + s"${names.mkString(",")} ") } } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Generate if !p.child.resolved || !p.generator.resolved => p - case p: Generate if p.resolved == false => - // if the generator output names are not specified, we will use the default ones. - Generate( - p.generator, - join = p.join, - outer = p.outer, - p.qualifier, - makeGeneratorOutput(p.generator, p.generatorOutput), p.child) - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0f349f9d11..01f4b6e9bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -59,6 +59,9 @@ case class Generate( child: LogicalPlan) extends UnaryNode { + /** The set of all attributes produced by this node. */ + def generatedSet: AttributeSet = AttributeSet(generatorOutput) + override lazy val resolved: Boolean = { generator.resolved && childrenResolved && diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6f2f35564d..e1d6ac462f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { StructField("cField", StringType) :: Nil ))()) + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -159,11 +162,16 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { } } - errorMessages.foreach(m => assert(error.getMessage contains m)) + errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) } } errorTest( + "too many generators", + listRelation.select(Explode('list).as('a), Explode('list).as('b)), + "only one generator" :: "explode" :: Nil) + + errorTest( "unresolved attributes", testRelation.select('abcd), "cannot resolve" :: "abcd" :: Nil) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8bf1320ccb..dc0aeea7c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql import scala.language.implicitConversions +import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} import org.apache.spark.sql.types._ @@ -728,6 +729,30 @@ class Column(protected[sql] val expr: Expression) extends Logging { def as(alias: String): Column = Alias(expr, alias)() /** + * (Scala-specific) Assigns the given aliases to the results of a table generating function. + * {{{ + * // Renames colA to colB in select output. + * df.select(explode($"myMap").as("key" :: "value" :: Nil)) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases) + + /** + * Assigns the given aliases to the results of a table generating function. + * {{{ + * // Renames colA to colB in select output. + * df.select(explode($"myMap").as("key" :: "value" :: Nil)) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def as(aliases: Array[String]): Column = MultiAlias(expr, aliases) + + /** * Gives the column an alias. * {{{ * // Renames colA to colB in select output. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4fd5105c27..2e20c3d3f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,7 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -593,6 +593,9 @@ class DataFrame private[sql]( def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { case Column(expr: NamedExpression) => expr + // Leave an unaliased explode with an empty list of names since the analzyer will generate the + // correct defaults after the nested expression's type has been resolved. + case Column(explode: Explode) => MultiAlias(explode, Nil) case Column(expr: Expression) => Alias(expr, expr.prettyString)() } // When user continuously call `select`, speed up analysis by collapsing `Project` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4404ad8ad6..6640631cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -364,6 +364,11 @@ object functions { def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) /** + * Creates a new row for each element in the given array or map column. + */ + def explode(e: Column): Column = Explode(e.expr) + + /** * Converts a string exprsesion to lower case. * * @group normal_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 269e185543..9bdf201b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -27,6 +27,66 @@ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + test("single explode") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + checkAnswer( + df.select(explode('intList)), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + + test("explode and other columns") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + + checkAnswer( + df.select($"a", explode('intList)), + Row(1, 1) :: + Row(1, 2) :: + Row(1, 3) :: Nil) + + checkAnswer( + df.select($"*", explode('intList)), + Row(1, Seq(1,2,3), 1) :: + Row(1, Seq(1,2,3), 2) :: + Row(1, Seq(1,2,3), 3) :: Nil) + } + + test("aliased explode") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + + checkAnswer( + df.select(explode('intList).as('int)).select('int), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + df.select(explode('intList).as('int)).select(sum('int)), + Row(6) :: Nil) + } + + test("explode on map") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map)), + Row("a", "b")) + } + + test("explode on map with aliases") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), + Row("a", "b")) + } + + test("self join explode") { + val df = Seq((1, Seq(1,2,3))).toDF("a", "intList") + val exploded = df.select(explode('intList).as('i)) + + checkAnswer( + exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), + Row(3) :: Nil) + } + test("collect on column produced by a binary operator") { val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) |