From c079420d7c55d8972db716a2695a5ddd606d11cd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 13 Mar 2016 12:02:52 +0800 Subject: [SPARK-13841][SQL] Removes Dataset.collectRows()/takeRows() ## What changes were proposed in this pull request? This PR removes two methods, `collectRows()` and `takeRows()`, from `Dataset[T]`. These methods were added in PR #11443, and were later considered not useful. ## How was this patch tested? Existing tests should do the work. Author: Cheng Lian Closes #11678 from liancheng/remove-collect-rows-and-take-rows. --- .../scala/org/apache/spark/sql/DataFrame.scala | 18 ---------- .../org/apache/spark/sql/JavaApplySchemaSuite.java | 4 +-- .../org/apache/spark/sql/JavaDataFrameSuite.java | 39 +++++++++++----------- 3 files changed, 22 insertions(+), 39 deletions(-) (limited to 'sql/core/src') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f1791e6943..1ea7db0388 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1762,10 +1762,6 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) - def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds => - ds.collectRows(needCallback = false) - } - /** * Returns the first `n` rows in the [[DataFrame]] as a list. * @@ -1790,8 +1786,6 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = collect(needCallback = true) - def collectRows(): Array[Row] = collectRows(needCallback = true) - /** * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. * @@ -1820,18 +1814,6 @@ class Dataset[T] private[sql]( } } - private def collectRows(needCallback: Boolean): Array[Row] = { - def execute(): Array[Row] = withNewExecutionId { - queryExecution.executedPlan.executeCollectPublic() - } - - if (needCallback) { - withCallback("collect", toDF())(_ => execute()) - } else { - execute() - } - } - /** * Returns the number of rows in the [[DataFrame]]. * @group action diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 42af813bc1..ae9c8cc1ba 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -109,13 +109,13 @@ public class JavaApplySchemaSuite implements Serializable { Dataset df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows(); + List actual = sqlContext.sql("SELECT * FROM people").collectAsList(); List expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); expected.add(RowFactory.create("Yin", 28)); - Assert.assertEquals(expected, Arrays.asList(actual)); + Assert.assertEquals(expected, actual); } @Test diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 47cc74dbc1..42554720ed 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -19,6 +19,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -65,7 +66,7 @@ public class JavaDataFrameSuite { @Test public void testExecution() { Dataset df = context.table("testData").filter("key = 1"); - Assert.assertEquals(1, df.select("key").collectRows()[0].get(0)); + Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); } @Test @@ -208,8 +209,8 @@ public class JavaDataFrameSuite { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List rows = Arrays.asList(RowFactory.create(0)); Dataset df = context.createDataFrame(rows, schema); - Row[] result = df.collectRows(); - Assert.assertEquals(1, result.length); + List result = df.collectAsList(); + Assert.assertEquals(1, result.size()); } @Test @@ -241,8 +242,8 @@ public class JavaDataFrameSuite { Assert.assertEquals("a_b", columnNames[0]); Assert.assertEquals("2", columnNames[1]); Assert.assertEquals("1", columnNames[2]); - Row[] rows = crosstab.collectRows(); - Arrays.sort(rows, crosstabRowComparator); + List rows = crosstab.collectAsList(); + Collections.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); @@ -257,7 +258,7 @@ public class JavaDataFrameSuite { Dataset df = context.table("testData2"); String[] cols = {"a"}; Dataset results = df.stat().freqItems(cols, 0.2); - Assert.assertTrue(results.collectRows()[0].getSeq(0).contains(1)); + Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); } @Test @@ -278,27 +279,27 @@ public class JavaDataFrameSuite { public void testSampleBy() { Dataset df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); Dataset sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); - Row[] actual = sampled.groupBy("key").count().orderBy("key").collectRows(); - Assert.assertEquals(0, actual[0].getLong(0)); - Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8); - Assert.assertEquals(1, actual[1].getLong(0)); - Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); + List actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); + Assert.assertEquals(0, actual.get(0).getLong(0)); + Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8); + Assert.assertEquals(1, actual.get(1).getLong(0)); + Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13); } @Test public void pivot() { Dataset df = context.table("courseSales"); - Row[] actual = df.groupBy("year") + List actual = df.groupBy("year") .pivot("course", Arrays.asList("dotNET", "Java")) - .agg(sum("earnings")).orderBy("year").collectRows(); + .agg(sum("earnings")).orderBy("year").collectAsList(); - Assert.assertEquals(2012, actual[0].getInt(0)); - Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01); - Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01); + Assert.assertEquals(2012, actual.get(0).getInt(0)); + Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01); - Assert.assertEquals(2013, actual[1].getInt(0)); - Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01); - Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01); + Assert.assertEquals(2013, actual.get(1).getInt(0)); + Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); } @Test -- cgit v1.2.3