 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 19
 core/src/test/java/org/apache/spark/JavaAPISuite.java | 8
 python/pyspark/sql/dataframe.py | 18
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 6
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 18
 sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 38
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 4
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 20
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 17
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 4
 10 files changed, 130 insertions(+), 22 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index d80d94a588..330255f892 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -407,12 +407,27 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](
- this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+ randomSampleWithRange(x(0), x(1), seed)
}.toArray
}
/**
+ * Internal method exposed for random splits in DataFrames. Samples an RDD given a probability
+ * range.
+ * @param lb lower bound to use for the Bernoulli sampler
+ * @param ub upper bound to use for the Bernoulli sampler
+ * @param seed the seed for the random number generator
+ * @return A random sub-sample of the RDD without replacement.
+ */
+ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
+ this.mapPartitionsWithIndex { case (index, partition) =>
+ val sampler = new BernoulliCellSampler[T](lb, ub)
+ sampler.setSeed(seed + index)
+ sampler.sample(partition)
+ }
+ }
+
+ /**
* Return a fixed-size sampled subset of this RDD in an array
*
* @param withReplacement whether sampling is done with replacement
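
Note: the new randomSampleWithRange is the building block for DataFrame.randomSplit further
down: the weights are normalized, turned into cumulative bounds with scanLeft, and each
adjacent pair becomes one Bernoulli cell. A minimal sketch of that bound computation (plain
Scala, no Spark required; the println is illustrative only):

    // Hypothetical illustration of how randomSplit derives per-split probability ranges.
    val weights = Array(1.0, 2.0, 3.0)
    val sum = weights.sum
    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
    // normalizedCumWeights == Array(0.0, 0.1666..., 0.5, 1.0)
    normalizedCumWeights.sliding(2).foreach { case Array(lb, ub) =>
      println(f"split covers [$lb%.4f, $ub%.4f)")  // three disjoint ranges covering [0, 1)
    }
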
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 34ac9361d4..c2089b0e56 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -157,11 +157,11 @@ public class JavaAPISuite implements Serializable {
public void randomSplit() {
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
JavaRDD<Integer> rdd = sc.parallelize(ints);
- JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
+ JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
Assert.assertEquals(3, splits.length);
- Assert.assertEquals(2, splits[0].count());
- Assert.assertEquals(3, splits[1].count());
- Assert.assertEquals(5, splits[2].count());
+ Assert.assertEquals(1, splits[0].count());
+ Assert.assertEquals(2, splits[1].count());
+ Assert.assertEquals(7, splits[2].count());
}
@Test
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d9cbbc68b3..3074af3ed2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -426,7 +426,7 @@ class DataFrame(object):
def sample(self, withReplacement, fraction, seed=None):
"""Returns a sampled subset of this :class:`DataFrame`.
- >>> df.sample(False, 0.5, 97).count()
+ >>> df.sample(False, 0.5, 42).count()
1
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
@@ -434,6 +434,22 @@ class DataFrame(object):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
+ def randomSplit(self, weights, seed=None):
+ """Randomly splits this :class:`DataFrame` with the provided weights.
+
+ >>> splits = df4.randomSplit([1.0, 2.0], 24)
+ >>> splits[0].count()
+ 1
+
+ >>> splits[1].count()
+ 3
+ """
+ for w in weights:
+ assert w >= 0.0, "Negative weight value: %s" % w
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
+ rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
+ return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
+
@property
def dtypes(self):
"""Returns all column names and their data types as a list.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d5aba9644..fa6cc7a1a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -278,12 +278,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
- def sample(
- fraction: Double,
- withReplacement: Boolean = true,
- seed: Int = (math.random * 1000).toInt): LogicalPlan =
- Sample(fraction, withReplacement, seed, logicalPlan)
-
// TODO specify the output column names
def generate(
generator: Generator,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 608e272da7..21208c8a5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
- extends UnaryNode {
+/**
+ * Sample the dataset.
+ *
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LogicalPlan
+ */
+case class Sample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long,
+ child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
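
Note: existing single-fraction callers keep working by pinning the lower bound at 0.0, and the
expected fraction of rows kept is always upperBound - lowerBound. A sketch of the two ways the
node is now built (the helper names below are hypothetical; callers normally go through
DataFrame.sample or DataFrame.randomSplit):

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample}

    // Plain sample: lower bound fixed at 0.0, keeps roughly `fraction` of the rows.
    def plainSample(fraction: Double, seed: Long, child: LogicalPlan): Sample =
      Sample(0.0, fraction, withReplacement = false, seed, child)

    // One slice of a random split: keeps roughly (ub - lb) of the rows.
    def splitSlice(lb: Double, ub: Double, seed: Long, child: LogicalPlan): Sample =
      Sample(lb, ub, withReplacement = false, seed, child)
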
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 0e896e5693..0d02e14c21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -706,7 +706,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
- Sample(fraction, withReplacement, seed, logicalPlan)
+ Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}
/**
@@ -721,6 +721,42 @@ class DataFrame private[sql](
}
/**
+ * Randomly splits this [[DataFrame]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ * @group dfops
+ */
+ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+ val sum = weights.sum
+ val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+ normalizedCumWeights.sliding(2).map { x =>
+ new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+ }.toArray
+ }
+
+ /**
+ * Randomly splits this [[DataFrame]] with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @group dfops
+ */
+ def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+ randomSplit(weights, Utils.random.nextLong)
+ }
+
+ /**
+ * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python API.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ * @group dfops
+ */
+ def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+ randomSplit(weights.toArray, seed)
+ }
+
+ /**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* the input row are implicitly joined with each row that is output by the function.
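
Note: with these overloads in place, a typical caller splits once and uses the pieces
independently. A minimal usage sketch, assuming a DataFrame named df is already in scope:

    // Sketch only: reproducible train/test split with an explicit seed.
    val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 12345L)
    // Weights are normalized, so Array(4.0, 1.0) would describe the same 80/20 split.
    assert(train.count() + test.count() == df.count())  // splits are disjoint and complete
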
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index af58911cc0..326e8ce4ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
- case logical.Sample(fraction, withReplacement, seed, child) =>
- execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case logical.Sample(lb, ub, withReplacement, seed, child) =>
+ execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 1afdb40941..5ca11e67a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
/**
* :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
*/
@DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
+case class Sample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long,
+ child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output
// TODO: How to pick seed?
override def execute(): RDD[Row] = {
- child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+ if (withReplacement) {
+ child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+ } else {
+ child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+ }
}
}
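
Note: the no-replacement branch relies on every slice of a randomSplit sharing the same seed, so
each row receives the same uniform draw under every sampler and lands in exactly one [lb, ub)
cell. An illustrative sketch of that property using BernoulliCellSampler directly (not part of
the patch; plain driver-side Scala):

    import org.apache.spark.util.random.BernoulliCellSampler

    val data = (1 to 20).toList
    val lower = new BernoulliCellSampler[Int](0.0, 0.4)
    val upper = new BernoulliCellSampler[Int](0.4, 1.0)
    lower.setSeed(7L)
    upper.setSeed(7L)
    val a = lower.sample(data.iterator).toList
    val b = upper.sample(data.iterator).toList
    assert((a ++ b).sorted == data)  // every element is accepted by exactly one range

With replacement, RDD.sample uses a Poisson-based sampler, so only the width
upperBound - lowerBound matters and the existing sample path is kept.
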
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5ec06d448e..b70e127b4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}
+ test("randomSplit") {
+ val n = 600
+ val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+ for (seed <- 1 to 5) {
+ val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+ assert(splits.length == 3, "wrong number of splits")
+
+ assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+ data.collect().toList, "incomplete or wrong split")
+
+ val s = splits.map(_.count())
+ assert(math.abs(s(0) - 100) < 50) // std = 9.13
+ assert(math.abs(s(1) - 200) < 50) // std = 11.55
+ assert(math.abs(s(2) - 300) < 50) // std = 12.25
+ }
+ }
+
test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2dc6463aba..0a86519e14 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
&& fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
s"Sampling fraction ($fraction) must be on interval [0, 100]")
- Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
+ Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
relation)
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
val fraction = numerator.toDouble / denominator.toDouble
- Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :