aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-04-29 15:34:05 -0700
committerReynold Xin <rxin@databricks.com>2015-04-29 15:34:05 -0700
commitd7dbce8f7da8a7fd01df6633a6043f51161b7d18 (patch)
tree6e57f1d7614527f4796071a541171cd07f8d98d2 /sql/core
parentc9d530e2e5123dbd4fd13fc487c890d6076b24bf (diff)
downloadspark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.gz
spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.bz2
spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.zip
[SPARK-7156][SQL] support RandomSplit in DataFrames
This is built on top of kaka1992 's PR #5711 using Logical plans. Author: Burak Yavuz <brkyvz@gmail.com> Closes #5761 from brkyvz/random-sample and squashes the following commits: a1fb0aa [Burak Yavuz] remove unrelated file 69669c3 [Burak Yavuz] fix broken test 1ddb3da [Burak Yavuz] copy base 6000328 [Burak Yavuz] added python api and fixed test 3c11d1b [Burak Yavuz] fixed broken test f400ade [Burak Yavuz] fix build errors 2384266 [Burak Yavuz] addressed comments v0.1 e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala38
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala17
4 files changed, 74 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 0e896e5693..0d02e14c21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -706,7 +706,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
- Sample(fraction, withReplacement, seed, logicalPlan)
+ Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}
/**
@@ -721,6 +721,42 @@ class DataFrame private[sql](
}
/**
+ * Randomly splits this [[DataFrame]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ * @group dfops
+ */
+ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+ val sum = weights.sum
+ val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+ normalizedCumWeights.sliding(2).map { x =>
+ new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+ }.toArray
+ }
+
+ /**
+ * Randomly splits this [[DataFrame]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @group dfops
+ */
+ def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+ randomSplit(weights, Utils.random.nextLong)
+ }
+
+ /**
+ * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ * @group dfops
+ */
+ def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+ randomSplit(weights.toArray, seed)
+ }
+
+ /**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* the input row are implicitly joined with each row that is output by the function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index af58911cc0..326e8ce4ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
- case logical.Sample(fraction, withReplacement, seed, child) =>
- execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case logical.Sample(lb, ub, withReplacement, seed, child) =>
+ execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 1afdb40941..5ca11e67a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
/**
* :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
*/
@DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
+case class Sample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long,
+ child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output
// TODO: How to pick seed?
override def execute(): RDD[Row] = {
- child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+ if (withReplacement) {
+ child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+ } else {
+ child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5ec06d448e..b70e127b4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}
+ test("randomSplit") {
+ val n = 600
+ val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+ for (seed <- 1 to 5) {
+ val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+ assert(splits.length == 3, "wrong number of splits")
+
+ assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+ data.collect().toList, "incomplete or wrong split")
+
+ val s = splits.map(_.count())
+ assert(math.abs(s(0) - 100) < 50) // std = 9.13
+ assert(math.abs(s(1) - 200) < 50) // std = 11.55
+ assert(math.abs(s(2) - 300) < 50) // std = 12.25
+ }
+ }
+
test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),