author     Michael Armbrust <michael@databricks.com>  2015-02-12 15:19:19 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-12 15:19:19 -0800
commit     ee04a8b19be8330bfc48f470ef365622162c915f (patch)
tree       9224e8284d7e991f9f310fe1d1e8d4299908497f /sql/core
parent     c352ffbdb9112714c176a747edff6115e9369e58 (diff)
[SPARK-5573][SQL] Add explode to dataframes
Author: Michael Armbrust <michael@databricks.com>

Closes #4546 from marmbrus/explode and squashes the following commits:

eefd33a [Michael Armbrust] whitespace
a8d496c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into explode
4af740e [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explode
dc86a5c [Michael Armbrust] simple version
d633d01 [Michael Armbrust] add scala specific
950707a [Michael Armbrust] fix comments
ba8854c [Michael Armbrust] [SPARK-5573][SQL] Add explode to dataframes
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala          | 38
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala      | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala     | 25
4 files changed, 100 insertions(+), 2 deletions(-)
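The patch adds two `explode` overloads to the `DataFrame` trait, implements them in `DataFrameImpl`, stubs them out in `IncomputableColumn`, and covers both in `DataFrameSuite`. A minimal usage sketch of the new API, assuming a SQLContext whose implicits (the era's `toDataFrame` helper, Symbol-to-Column conversion, and `countDistinct`) are in scope; the sample data and value names are illustrative, not part of the patch:

    // Sketch only: assumes org.apache.spark.sql.Row and a SQLContext's
    // implicits are imported; exact import paths moved around during the
    // Spark 1.3 API cleanup.
    case class Book(title: String, words: String)
    case class Word(word: String)

    val df = Seq(
      Book("Learning Spark", "spark rdd sql"),
      Book("Spark in Action", "spark sql")).toDataFrame

    // Column-based overload: the function sees the selected columns as a Row
    // and emits zero or more case-class instances per input row; the output
    // columns are joined onto the input columns.
    val allWords = df.explode('words) {
      case Row(words: String) => words.split(" ").map(Word(_))
    }

    // String-based overload: expands one named input column into one new
    // output column.
    val singleColumn = df.explode("words", "word") { words: String =>
      words.split(" ").toSeq
    }

    // As in the scaladoc example: count distinct titles containing each word.
    val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))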
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 13aff760e9..65257882f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -441,6 +442,43 @@ trait DataFrame extends RDDApi[Row] with Serializable {
sample(withReplacement, fraction, Utils.random.nextLong)
}
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
+ * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
+ * the input row are implicitly joined with each row that is output by the function.
+ *
+ * The following example uses this function to count the number of books which contain
+ * a given word:
+ *
+ * {{{
+ * case class Book(title: String, words: String)
+ * val df: DataFrame // with columns "title" and "words"
+ *
+ * case class Word(word: String)
+ * val allWords = df.explode('words) {
+ * case Row(words: String) => words.split(" ").map(Word(_))
+ * }
+ *
+ * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
+ * }}}
+ */
+ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame
+
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
+ * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
+ * columns of the input row are implicitly joined with each value that is output by the function.
+ *
+ * {{{
+ * df.explode("words", "word") { words: String => words.split(" ") }
+ * }}}
+ */
+ def explode[A, B : TypeTag](
+ inputColumn: String,
+ outputColumn: String)(
+ f: A => TraversableOnce[B]): DataFrame
+
/////////////////////////////////////////////////////////////////////////////
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 4c6e19cace..bb5c6226a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -21,6 +21,7 @@ import java.io.CharArrayWriter
import scala.language.implicitConversions
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConversions._
import com.fasterxml.jackson.core.JsonFactory
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
+import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}
-
/**
* Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
*/
@@ -282,6 +282,32 @@ private[sql] class DataFrameImpl protected[sql](
Sample(fraction, withReplacement, seed, logicalPlan)
}
+ override def explode[A <: Product : TypeTag]
+ (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributes = schema.toAttributes
+ val rowFunction =
+ f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
+ val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+
+ Generate(generator, join = true, outer = false, None, logicalPlan)
+ }
+
+ override def explode[A, B : TypeTag](
+ inputColumn: String,
+ outputColumn: String)(
+ f: A => TraversableOnce[B]): DataFrame = {
+ val dataType = ScalaReflection.schemaFor[B].dataType
+ val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+ def rowFunction(row: Row) = {
+ f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
+ }
+ val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+
+ Generate(generator, join = true, outer = false, None, logicalPlan)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// RDD API
/////////////////////////////////////////////////////////////////////////////
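Both overloads reduce to the same plan: reflect a schema for the output type, adapt the user's typed function into a `Row => TraversableOnce[Row]` generator, wrap it in `UserDefinedGenerator`, and emit a `Generate` node with `join = true` so generated values are joined back onto the source columns, and `outer = false` so rows whose function yields nothing are dropped (plain `LATERAL VIEW`, not `LATERAL VIEW OUTER`). A standalone sketch of just the adaptation step for the single-column overload; `convert` stands in for `ScalaReflection.convertToCatalyst(_, dataType)` and the helper name is illustrative:

    // Illustrative only: mirrors the rowFunction built in
    // DataFrameImpl.explode(inputColumn, outputColumn) above.
    import org.apache.spark.sql.Row

    def adaptRowFunction[A, B](
        f: A => TraversableOnce[B],
        convert: Any => Any): Row => TraversableOnce[Row] =
      (row: Row) => f(row(0).asInstanceOf[A]).map(b => Row(convert(b)))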
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 4f9d92d976..19c8e3b4f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
@@ -110,6 +111,14 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
+ override def explode[A <: Product : TypeTag]
+ (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()
+
+ override def explode[A, B : TypeTag](
+ inputColumn: String,
+ outputColumn: String)(
+ f: A => TraversableOnce[B]): DataFrame = err()
+
/////////////////////////////////////////////////////////////////////////////
override def head(n: Int): Array[Row] = err()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7be9215a44..33b35f376b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -98,6 +98,31 @@ class DataFrameSuite extends QueryTest {
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
}
+ test("simple explode") {
+ val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")
+
+ checkAnswer(
+ df.explode("words", "word") { words: String => words.split(" ").toSeq }.select('word),
+ Row("a") :: Row("b") :: Row("c") :: Row("d") :: Row("e") :: Nil
+ )
+ }
+
+ test("explode") {
+ val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
+ val df2 =
+ df.explode('letters) {
+ case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
+ }
+
+ checkAnswer(
+ df2
+ .select('_1 as 'letter, 'number)
+ .groupBy('letter)
+ .agg('letter, countDistinct('number)),
+ Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
+ )
+ }
+
test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),