aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala38
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala17
4 files changed, 74 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 0e896e5693..0d02e14c21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -706,7 +706,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
- Sample(fraction, withReplacement, seed, logicalPlan)
+ Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}
/**
@@ -721,6 +721,42 @@ class DataFrame private[sql](
}
/**
+ * Randomly splits this [[DataFrame]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ * @group dfops
+ */
+ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+ val sum = weights.sum
+ val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+ normalizedCumWeights.sliding(2).map { x =>
+ new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+ }.toArray
+ }
+
+ /**
+ * Randomly splits this [[DataFrame]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @group dfops
+ */
+ def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+ randomSplit(weights, Utils.random.nextLong)
+ }
+
+ /**
+ * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ * @group dfops
+ */
+ def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+ randomSplit(weights.toArray, seed)
+ }
+
+ /**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* the input row are implicitly joined with each row that is output by the function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index af58911cc0..326e8ce4ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
- case logical.Sample(fraction, withReplacement, seed, child) =>
- execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case logical.Sample(lb, ub, withReplacement, seed, child) =>
+ execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 1afdb40941..5ca11e67a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
/**
* :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
*/
@DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
+case class Sample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long,
+ child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output
// TODO: How to pick seed?
override def execute(): RDD[Row] = {
- child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+ if (withReplacement) {
+ child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+ } else {
+ child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5ec06d448e..b70e127b4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}
+ test("randomSplit") {
+ val n = 600
+ val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+ for (seed <- 1 to 5) {
+ val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+ assert(splits.length == 3, "wrong number of splits")
+
+ assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+ data.collect().toList, "incomplete or wrong split")
+
+ val s = splits.map(_.count())
+ assert(math.abs(s(0) - 100) < 50) // std = 9.13
+ assert(math.abs(s(1) - 200) < 50) // std = 11.55
+ assert(math.abs(s(2) - 300) < 50) // std = 12.25
+ }
+ }
+
test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),