aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-03-13 12:02:52 +0800
committerCheng Lian <lian@databricks.com>2016-03-13 12:02:52 +0800
commitc079420d7c55d8972db716a2695a5ddd606d11cd (patch)
treecf33f0c9895ff28d2ce686c210d0d4afd7f6b78d /sql/core/src
parent4eace4d384f0e12b4934019d8654b5e3886ddaef (diff)
downloadspark-c079420d7c55d8972db716a2695a5ddd606d11cd.tar.gz
spark-c079420d7c55d8972db716a2695a5ddd606d11cd.tar.bz2
spark-c079420d7c55d8972db716a2695a5ddd606d11cd.zip
[SPARK-13841][SQL] Removes Dataset.collectRows()/takeRows()
## What changes were proposed in this pull request? This PR removes two methods, `collectRows()` and `takeRows()`, from `Dataset[T]`. These methods were added in PR #11443, and were later considered not useful. ## How was this patch tested? Existing tests should do the work. Author: Cheng Lian <lian@databricks.com> Closes #11678 from liancheng/remove-collect-rows-and-take-rows.
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala18
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java4
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java39
3 files changed, 22 insertions, 39 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f1791e6943..1ea7db0388 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1762,10 +1762,6 @@ class Dataset[T] private[sql](
*/
def take(n: Int): Array[T] = head(n)
- def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds =>
- ds.collectRows(needCallback = false)
- }
-
/**
* Returns the first `n` rows in the [[DataFrame]] as a list.
*
@@ -1790,8 +1786,6 @@ class Dataset[T] private[sql](
*/
def collect(): Array[T] = collect(needCallback = true)
- def collectRows(): Array[Row] = collectRows(needCallback = true)
-
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
*
@@ -1820,18 +1814,6 @@ class Dataset[T] private[sql](
}
}
- private def collectRows(needCallback: Boolean): Array[Row] = {
- def execute(): Array[Row] = withNewExecutionId {
- queryExecution.executedPlan.executeCollectPublic()
- }
-
- if (needCallback) {
- withCallback("collect", toDF())(_ => execute())
- } else {
- execute()
- }
- }
-
/**
* Returns the number of rows in the [[DataFrame]].
* @group action
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 42af813bc1..ae9c8cc1ba 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -109,13 +109,13 @@ public class JavaApplySchemaSuite implements Serializable {
Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
df.registerTempTable("people");
- Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();
+ List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList();
List<Row> expected = new ArrayList<>(2);
expected.add(RowFactory.create("Michael", 29));
expected.add(RowFactory.create("Yin", 28));
- Assert.assertEquals(expected, Arrays.asList(actual));
+ Assert.assertEquals(expected, actual);
}
@Test
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 47cc74dbc1..42554720ed 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -19,6 +19,7 @@ package test.org.apache.spark.sql;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
@@ -65,7 +66,7 @@ public class JavaDataFrameSuite {
@Test
public void testExecution() {
Dataset<Row> df = context.table("testData").filter("key = 1");
- Assert.assertEquals(1, df.select("key").collectRows()[0].get(0));
+ Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
}
@Test
@@ -208,8 +209,8 @@ public class JavaDataFrameSuite {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));
Dataset<Row> df = context.createDataFrame(rows, schema);
- Row[] result = df.collectRows();
- Assert.assertEquals(1, result.length);
+ List<Row> result = df.collectAsList();
+ Assert.assertEquals(1, result.size());
}
@Test
@@ -241,8 +242,8 @@ public class JavaDataFrameSuite {
Assert.assertEquals("a_b", columnNames[0]);
Assert.assertEquals("2", columnNames[1]);
Assert.assertEquals("1", columnNames[2]);
- Row[] rows = crosstab.collectRows();
- Arrays.sort(rows, crosstabRowComparator);
+ List<Row> rows = crosstab.collectAsList();
+ Collections.sort(rows, crosstabRowComparator);
Integer count = 1;
for (Row row : rows) {
Assert.assertEquals(row.get(0).toString(), count.toString());
@@ -257,7 +258,7 @@ public class JavaDataFrameSuite {
Dataset<Row> df = context.table("testData2");
String[] cols = {"a"};
Dataset<Row> results = df.stat().freqItems(cols, 0.2);
- Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1));
+ Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1));
}
@Test
@@ -278,27 +279,27 @@ public class JavaDataFrameSuite {
public void testSampleBy() {
Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
- Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows();
- Assert.assertEquals(0, actual[0].getLong(0));
- Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8);
- Assert.assertEquals(1, actual[1].getLong(0));
- Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
+ List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();
+ Assert.assertEquals(0, actual.get(0).getLong(0));
+ Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
+ Assert.assertEquals(1, actual.get(1).getLong(0));
+ Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
}
@Test
public void pivot() {
Dataset<Row> df = context.table("courseSales");
- Row[] actual = df.groupBy("year")
+ List<Row> actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
- .agg(sum("earnings")).orderBy("year").collectRows();
+ .agg(sum("earnings")).orderBy("year").collectAsList();
- Assert.assertEquals(2012, actual[0].getInt(0));
- Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
- Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+ Assert.assertEquals(2012, actual.get(0).getInt(0));
+ Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
+ Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);
- Assert.assertEquals(2013, actual[1].getInt(0));
- Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
- Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+ Assert.assertEquals(2013, actual.get(1).getInt(0));
+ Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
+ Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}
@Test