author    Davies Liu <davies.liu@gmail.com>  2016-02-25 11:53:48 -0800
committer Davies Liu <davies.liu@gmail.com>  2016-02-25 11:53:48 -0800
commit    751724b1320d38fd94186df3d8f1ca887f21947a
tree      ef365e952284a7ec26aee882caa429a729223c9d /sql
parent    46f6e79316b72afea0c9b1559ea662dd3e95e57b
Revert "[SPARK-13457][SQL] Removes DataFrame RDD operations"
This reverts commit 157fe64f3ecbd13b7286560286e50235eecfe30e.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala | 5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala | 2
12 files changed, 53 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index be902d688e..abb8fe552b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1427,6 +1427,48 @@ class DataFrame private[sql](
def transform[U](t: DataFrame => DataFrame): DataFrame = t(this)
/**
+ * Returns a new RDD by applying a function to all rows of this DataFrame.
+ * @group rdd
+ * @since 1.3.0
+ */
+ def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
+
+ /**
+ * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+ * and then flattening the results.
+ * @group rdd
+ * @since 1.3.0
+ */
+ def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
+ /**
+ * Returns a new RDD by applying a function to each partition of this DataFrame.
+ * @group rdd
+ * @since 1.3.0
+ */
+ def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+ rdd.mapPartitions(f)
+ }
+
+ /**
+ * Applies a function `f` to all rows.
+ * @group rdd
+ * @since 1.3.0
+ */
+ def foreach(f: Row => Unit): Unit = withNewExecutionId {
+ rdd.foreach(f)
+ }
+
+ /**
+ * Applies a function f to each partition of this [[DataFrame]].
+ * @group rdd
+ * @since 1.3.0
+ */
+ def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
+ rdd.foreachPartition(f)
+ }
+
+ /**
* Returns the first `n` rows in the [[DataFrame]].
*
* Running take requires moving data into the application's driver process, and doing so with
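The hunk above restores the RDD-style shortcuts that SPARK-13457 had removed, so user code can again call `map`, `flatMap`, `mapPartitions`, `foreach`, and `foreachPartition` directly on a `DataFrame` instead of going through `.rdd` first. A rough usage sketch against Spark 1.x, assuming a `SQLContext` named `sqlContext`; the column names and sample data are made up for illustration:

    import org.apache.spark.rdd.RDD

    // Hypothetical two-column DataFrame; any DataFrame of Rows works the same way.
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "name")

    // With the reverted methods back, these run directly on the DataFrame and
    // return RDDs (not DataFrames), just like the .rdd-prefixed forms below.
    val ids: RDD[Int]    = df.map(row => row.getInt(0))
    val chars: RDD[Char] = df.flatMap(row => row.getString(1).toIterator)
    df.foreach(row => println(row))

    // Equivalent call without the shortcut, via the underlying RDD[Row]:
    val idsViaRdd: RDD[Int] = df.rdd.map(_.getInt(0))

Note that the restored `foreach` and `foreachPartition` wrap the action in `withNewExecutionId`, so these actions are tracked as SQL executions, which the `SQLListenerSuite` change further down relies on.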
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index a7258d742a..f06d16116e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -306,7 +306,6 @@ class GroupedData protected[sql](
val values = df.select(pivotColumn)
.distinct()
.sort(pivotColumn) // ensure that the output columns are in a consistent logical order
- .rdd
.map(_.get(0))
.take(maxValues + 1)
.toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 68a251757c..d912aeb70d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -100,7 +100,7 @@ private[r] object SQLUtils {
}
def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
- df.rdd.map(r => rowToRBytes(r))
+ df.map(r => rowToRBytes(r))
}
private[this] def doConversion(data: Object, dataType: DataType): Object = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index d7111a6a1c..e295722cac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -276,7 +276,7 @@ object JdbcUtils extends Logging {
val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
- df.rdd.foreachPartition { iterator =>
+ df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
}
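With `foreachPartition` back on `DataFrame`, the JDBC save path above again writes each partition through its own connection, batching inserts according to the `batchsize` property. A minimal sketch of that per-partition pattern, not the actual `savePartition` implementation; the JDBC URL, table, and columns here are hypothetical:

    import java.sql.DriverManager
    import org.apache.spark.sql.Row

    val url = "jdbc:postgresql://localhost/testdb"  // assumed URL; driver must be on the executor classpath
    val batchSize = 1000

    df.foreachPartition { rows: Iterator[Row] =>
      // One connection per partition, opened on the executor that owns it.
      val conn = DriverManager.getConnection(url)
      try {
        val stmt = conn.prepareStatement("INSERT INTO people (id, name) VALUES (?, ?)")
        var count = 0
        rows.foreach { row =>
          stmt.setInt(1, row.getInt(0))
          stmt.setString(2, row.getString(1))
          stmt.addBatch()
          count += 1
          if (count % batchSize == 0) stmt.executeBatch()  // flush a full batch
        }
        stmt.executeBatch()  // flush any remaining rows
      } finally {
        conn.close()
      }
    }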
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7d96ef6fe0..f54bff9f18 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -257,7 +257,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
test("count") {
- assert(testData2.count() === testData2.rdd.map(_ => 1).count())
+ assert(testData2.count() === testData2.map(_ => 1).count())
checkAnswer(
testData2.agg(count('a), sumDistinct('a)), // non-partial
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index bd51154c58..fbffe867e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -101,7 +101,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
(implicit df: DataFrame): Unit = {
def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = {
assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) {
- df.rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
+ df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index c85eeddc2c..3c74464d57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -599,7 +599,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
test("null and non-null strings") {
// Create a dataset where the first values are NULL and then some non-null values. The
// number of non-nulls needs to be bigger than the ParquetReader batch size.
- val data = sqlContext.range(200).rdd.map { i =>
+ val data = sqlContext.range(200).map { i =>
if (i.getLong(0) < 150) Row(None)
else Row("a")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 39920d8cc6..085e4a49a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -330,7 +330,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
// listener should ignore the non SQL stage
assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber)
- sqlContext.sparkContext.parallelize(1 to 10).toDF().rdd.foreach(i => ())
+ sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ())
sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000)
// listener should save the SQL stage
assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1)
@@ -398,7 +398,7 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
).toDF()
df.collect()
try {
- df.rdd.foreach(_ => throw new RuntimeException("Oops"))
+ df.foreach(_ => throw new RuntimeException("Oops"))
} catch {
case e: SparkException => // This is expected for a failed job
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 12a5542bd4..f141a9bd0f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -210,7 +210,7 @@ object SparkSubmitClassLoaderTest extends Logging {
}
// Second, we load classes at the executor side.
logInfo("Testing load classes at the executor side.")
- val result = df.rdd.mapPartitions { x =>
+ val result = df.mapPartitions { x =>
var exception: String = null
try {
Utils.classForName(args(0))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 1002487447..3208ebc9ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -664,13 +664,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("implement identity function using case statement") {
val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src")
- .rdd
.map { case Row(i: Int) => i }
.collect()
.toSet
val expected = sql("SELECT key FROM src")
- .rdd
.map { case Row(i: Int) => i }
.collect()
.toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 68249517f5..b11d1d9de0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -119,7 +119,6 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
// expr = (not leaf-0)
assertResult(10) {
sql("SELECT name, contacts FROM t where age > 5")
- .rdd
.flatMap(_.getAs[Seq[_]]("contacts"))
.count()
}
@@ -132,7 +131,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8")
assert(df.count() === 2)
assertResult(4) {
- df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count()
+ df.flatMap(_.getAs[Seq[_]]("contacts")).count()
}
}
@@ -144,7 +143,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8")
assert(df.count() === 3)
assertResult(6) {
- df.rdd.flatMap(_.getAs[Seq[_]]("contacts")).count()
+ df.flatMap(_.getAs[Seq[_]]("contacts")).count()
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index a127cf6e4b..68d5c7da1f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -854,7 +854,7 @@ abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with
test(s"hive udfs $table") {
checkAnswer(
sql(s"SELECT concat(stringField, stringField) FROM $table"),
- sql(s"SELECT stringField FROM $table").rdd.map {
+ sql(s"SELECT stringField FROM $table").map {
case Row(s: String) => Row(s + s)
}.collect().toSeq)
}