author    Reynold Xin <rxin@databricks.com>  2015-08-09 10:58:36 -0700
committer Reynold Xin <rxin@databricks.com>  2015-08-09 10:58:36 -0700
commit    e9c36938ba972b6fe3c9f6228508e3c9f1c876b2 (patch)
tree      20ecaccaa7f3e1e1cd97b6246afffe2cc6d37abc /sql
parent    3ca995b78f373251081f6877623649bfba3040b2 (diff)
[SPARK-9752][SQL] Support UnsafeRow in Sample operator.
In order for this to work, I had to disable gap sampling.

Author: Reynold Xin <rxin@databricks.com>

Closes #8040 from rxin/SPARK-9752 and squashes the following commits:

f9e248c [Reynold Xin] Fix the test case for real this time.
adbccb3 [Reynold Xin] Fixed test case.
589fb23 [Reynold Xin] Merge branch 'SPARK-9752' of github.com:rxin/spark into SPARK-9752
55ccddc [Reynold Xin] Fixed core test.
78fa895 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator.
c9e7112 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator.
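
For context, Sample is the physical operator behind DataFrame.sample and DataFrame.randomSplit. A minimal sketch of the user-facing code path this patch changes (the column name, fraction and seed are illustrative, not taken from the patch; sqlCtx stands in for an existing SQLContext):

    // Illustrative usage only; assumes an existing SQLContext named sqlCtx.
    import sqlCtx.implicits._

    val df = sqlCtx.sparkContext.parallelize(1 to 100, 2).toDF("id")

    // With replacement: planned as Sample backed by a PoissonSampler
    // (gap sampling disabled by this patch so rows need not be copied).
    val withRep = df.sample(withReplacement = true, fraction = 0.05, seed = 13)

    // Without replacement: planned as Sample using randomSampleWithRange
    // on the child RDD, now without the defensive row copy.
    val withoutRep = df.sample(withReplacement = false, fraction = 0.05, seed = 13)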
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 35
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 17
3 files changed, 49 insertions(+), 21 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 0680f31d40..c5d1ed0937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
+import org.apache.spark.util.random.PoissonSampler
import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -130,12 +131,21 @@ case class Sample(
{
override def output: Seq[Attribute] = child.output
- // TODO: How to pick seed?
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
- child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+ // Disable gap sampling since the gap sampling method buffers two rows internally,
+ // requiring us to copy the row, which is more expensive than the random number generator.
+ new PartitionwiseSampledRDD[InternalRow, InternalRow](
+ child.execute(),
+ new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+ preservesPartitioning = true,
+ seed)
} else {
- child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+ child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
}
}
}
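
The comment in the hunk above is the crux of the change: gap sampling buffers rows across iterator steps, and because operators may reuse a single mutable UnsafeRow instance, any buffered row would have to be copied first. A per-element Poisson draw makes the keep/duplicate decision immediately and never holds a row. A self-contained sketch of that per-element idea (this is not Spark's PoissonSampler; the commons-math3 dependency and the method name are assumptions for illustration):

    import org.apache.commons.math3.distribution.PoissonDistribution

    // Sketch of per-element sampling with replacement: decide, for each row,
    // how many times to emit it, without buffering rows across next() calls.
    def poissonSample[T](rows: Iterator[T], fraction: Double, seed: Long): Iterator[T] = {
      val poisson = new PoissonDistribution(fraction)   // fraction must be > 0
      poisson.reseedRandomGenerator(seed)
      rows.flatMap { row =>
        Iterator.fill(poisson.sample())(row)            // emit this row 0..n times
      }
    }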
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0e7659f443..8f5984e4a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest {
private def toLetter(i: Int): String = (i + 97).toChar.toString
+ test("sample with replacement") {
+ val n = 100
+ val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ checkAnswer(
+ data.sample(withReplacement = true, 0.05, seed = 13),
+ Seq(5, 10, 52, 73).map(Row(_))
+ )
+ }
+
+ test("sample without replacement") {
+ val n = 100
+ val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ checkAnswer(
+ data.sample(withReplacement = false, 0.05, seed = 13),
+ Seq(16, 23, 88, 100).map(Row(_))
+ )
+ }
+
+ test("randomSplit") {
+ val n = 600
+ val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ for (seed <- 1 to 5) {
+ val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+ assert(splits.length == 3, "wrong number of splits")
+
+ assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+ data.collect().toList, "incomplete or wrong split")
+
+ val s = splits.map(_.count())
+ assert(math.abs(s(0) - 100) < 50) // std = 9.13
+ assert(math.abs(s(1) - 200) < 50) // std = 11.55
+ assert(math.abs(s(2) - 300) < 50) // std = 12.25
+ }
+ }
+
test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.stat.corr("a", "b", "pearson")
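
As a side note on the "std = ..." comments in the randomSplit test above: with weights 1:2:3 over n = 600 rows, each split's size is approximately binomial, so the expected counts and standard deviations work out as below (a quick check, not part of the patch):

    val n = 600.0
    val weights = Array(1.0, 2.0, 3.0)
    val ps = weights.map(_ / weights.sum)                 // 1/6, 1/3, 1/2
    val expected = ps.map(_ * n)                          // 100, 200, 300
    val stds = ps.map(p => math.sqrt(n * p * (1 - p)))    // ~9.13, ~11.55, ~12.25

The assertion tolerance of 50 is more than four standard deviations for every split, so spurious failures across the five seeds are unlikely.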
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f9cc6d1f3c..0212637a82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
- test("randomSplit") {
- val n = 600
- val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id")
- for (seed <- 1 to 5) {
- val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
- assert(splits.length == 3, "wrong number of splits")
-
- assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
- data.collect().toList, "incomplete or wrong split")
-
- val s = splits.map(_.count())
- assert(math.abs(s(0) - 100) < 50) // std = 9.13
- assert(math.abs(s(1) - 200) < 50) // std = 11.55
- assert(math.abs(s(2) - 300) < 50) // std = 12.25
- }
- }
-
test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),