From 751724b1320d38fd94186df3d8f1ca887f21947a Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 25 Feb 2016 11:53:48 -0800
Subject: Revert "[SPARK-13457][SQL] Removes DataFrame RDD operations"

This reverts commit 157fe64f3ecbd13b7286560286e50235eecfe30e.
---
 .../scala/org/apache/spark/sql/DataFrame.scala      | 42 ++++++++++++++++++++++
 .../scala/org/apache/spark/sql/GroupedData.scala    |  1 -
 .../org/apache/spark/sql/api/r/SQLUtils.scala       |  2 +-
 .../sql/execution/datasources/jdbc/JdbcUtils.scala  |  2 +-
 .../apache/spark/sql/DataFrameAggregateSuite.scala  |  2 +-
 .../datasources/parquet/ParquetFilterSuite.scala    |  2 +-
 .../datasources/parquet/ParquetIOSuite.scala        |  2 +-
 .../spark/sql/execution/ui/SQLListenerSuite.scala   |  4 +--
 .../spark/sql/hive/HiveSparkSubmitSuite.scala       |  2 +-
 .../spark/sql/hive/execution/HiveQuerySuite.scala   |  2 --
 .../apache/spark/sql/hive/orc/OrcQuerySuite.scala   |  5 ++-
 .../org/apache/spark/sql/hive/parquetSuites.scala   |  2 +-
 12 files changed, 53 insertions(+), 15 deletions(-)

(limited to 'sql')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index be902d688e..abb8fe552b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1426,6 +1426,48 @@ class DataFrame private[sql](
    */
   def transform[U](t: DataFrame => DataFrame): DataFrame = t(this)
 
+  /**
+   * Returns a new RDD by applying a function to all rows of this DataFrame.
+   * @group rdd
+   * @since 1.3.0
+   */
+  def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
+
+  /**
+   * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+   * and then flattening the results.
+   * @group rdd
+   * @since 1.3.0
+   */
+  def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
+  /**
+   * Returns a new RDD by applying a function to each partition of this DataFrame.
+   * @group rdd
+   * @since 1.3.0
+   */
+  def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+    rdd.mapPartitions(f)
+  }
+
+  /**
+   * Applies a function `f` to all rows.
+   * @group rdd
+   * @since 1.3.0
+   */
+  def foreach(f: Row => Unit): Unit = withNewExecutionId {
+    rdd.foreach(f)
+  }
+
+  /**
+   * Applies a function f to each partition of this [[DataFrame]].
+   * @group rdd
+   * @since 1.3.0
+   */
+  def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
+    rdd.foreachPartition(f)
+  }
+
   /**
    * Returns the first `n` rows in the [[DataFrame]].
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index a7258d742a..f06d16116e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -306,7 +306,6 @@ class GroupedData protected[sql](
       val values = df.select(pivotColumn)
         .distinct()
         .sort(pivotColumn)  // ensure that the output columns are in a consistent logical order
-        .rdd
         .map(_.get(0))
         .take(maxValues + 1)
         .toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 68a251757c..d912aeb70d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -100,7 +100,7 @@ private[r] object SQLUtils {
   }
 
   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
-    df.rdd.map(r => rowToRBytes(r))
+    df.map(r => rowToRBytes(r))
   }
 
   private[this] def doConversion(data: Object, dataType: DataType): Object = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index d7111a6a1c..e295722cac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -276,7 +276,7 @@ object JdbcUtils extends Logging {
     val rddSchema = df.schema
     val getConnection: () => Connection = createConnectionFactory(url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
-    df.rdd.foreachPartition { iterator =>
+    df.foreachPartition { iterator =>
       savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7d96ef6fe0..f54bff9f18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -257,7 +257,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   test("count") {
-    assert(testData2.count() === testData2.rdd.map(_ => 1).count())
+    assert(testData2.count() === testData2.map(_ => 1).count())
 
     checkAnswer(
       testData2.agg(count('a), sumDistinct('a)), // non-partial
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index bd51154c58..fbffe867e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -101,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       (implicit df: DataFrame): Unit = {
     def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = {
       assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) {
-        df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
+        df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
       }
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index c85eeddc2c..3c74464d57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -599,7 +599,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
   test("null and non-null strings") {
     // Create a dataset where the first values are NULL and then some non-null values. The
     // number of non-nulls needs to be bigger than the ParquetReader batch size.
-    val data = sqlContext.range(200).rdd.map { i =>
+    val data = sqlContext.range(200).map { i =>
       if (i.getLong(0) < 150) Row(None)
       else Row("a")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 39920d8cc6..085e4a49a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -330,7 +330,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     // listener should ignore the non SQL stage
     assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber)
 
-    sqlContext.sparkContext.parallelize(1 to 10).toDF().rdd.foreach(i => ())
+    sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ())
     sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000)
     // listener should save the SQL stage
     assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1)
@@ -398,7 +398,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
     ).toDF()
     df.collect()
     try {
-      df.rdd.foreach(_ => throw new RuntimeException("Oops"))
+      df.foreach(_ => throw new RuntimeException("Oops"))
     } catch {
       case e: SparkException => // This is expected for a failed job
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 12a5542bd4..f141a9bd0f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -210,7 +210,7 @@ object SparkSubmitClassLoaderTest extends Logging {
     }
     // Second, we load classes at the executor side.
     logInfo("Testing load classes at the executor side.")
-    val result = df.rdd.mapPartitions { x =>
+    val result = df.mapPartitions { x =>
       var exception: String = null
       try {
         Utils.classForName(args(0))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 1002487447..3208ebc9ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -664,13 +664,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
 
   test("implement identity function using case statement") {
     val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src")
-      .rdd
       .map { case Row(i: Int) => i }
       .collect()
       .toSet
 
     val expected = sql("SELECT key FROM src")
-      .rdd
       .map { case Row(i: Int) => i }
       .collect()
       .toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 68249517f5..b11d1d9de0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -119,7 +119,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
       // expr = (not leaf-0)
       assertResult(10) {
         sql("SELECT name, contacts FROM t where age > 5")
-          .rdd
           .flatMap(_.getAs[Seq[_]]("contacts"))
           .count()
       }
@@ -132,7 +131,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
       val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8")
       assert(df.count() === 2)
       assertResult(4) {
-        df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count()
+        df.flatMap(_.getAs[Seq[_]]("contacts")).count()
       }
     }
 
@@ -144,7 +143,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
      val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8")
       assert(df.count() === 3)
       assertResult(6) {
-        df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count()
+        df.flatMap(_.getAs[Seq[_]]("contacts")).count()
       }
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index a127cf6e4b..68d5c7da1f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -854,7 +854,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with
     test(s"hive udfs $table") {
       checkAnswer(
         sql(s"SELECT concat(stringField, stringField) FROM $table"),
-        sql(s"SELECT stringField FROM $table").rdd.map {
+        sql(s"SELECT stringField FROM $table").map {
           case Row(s: String) => Row(s + s)
         }.collect().toSeq)
     }
-- 
cgit v1.2.3