 python/pyspark/sql.py                                                      | 28
 sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala               | 15
 sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala  |  6
 3 files changed, 49 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 469f82473a..9807a84a66 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2085,6 +2085,34 @@ class SchemaRDD(RDD):
         else:
             raise ValueError("Can only subtract another SchemaRDD")
 
+    def sample(self, withReplacement, fraction, seed=None):
+        """
+        Return a sampled subset of this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.sample(False, 0.5, 97).count()
+        2L
+        """
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
+        return SchemaRDD(rdd, self.sql_ctx)
+
+    def takeSample(self, withReplacement, num, seed=None):
+        """Return a fixed-size sampled subset of this SchemaRDD.
+
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.takeSample(False, 2, 97)
+        [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+        """
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        with SCCallSiteSync(self.context) as css:
+            bytesInJava = self._jschema_rdd.baseSchemaRDD() \
+                .takeSampleToPython(withReplacement, num, long(seed)) \
+                .iterator()
+        cls = _create_cls(self.schema())
+        return map(cls, self._collect_iterator_through_file(bytesInJava))
+
 
 def _test():
     import doctest
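Taken together, the two new methods mirror the plain RDD API: sample is a lazy transformation that returns a new SchemaRDD, while takeSample is an action that returns a local list of Rows. A minimal usage sketch, assuming the same sqlCtx and rdd fixtures the doctests above rely on:

    srdd = sqlCtx.inferSchema(rdd)

    # Transformation: lazily keep each row with probability 0.5, no replacement.
    half = srdd.sample(False, 0.5, seed=97)
    n = half.count()                          # roughly half the rows, not exactly

    # Action: eagerly return exactly 2 sampled rows to the driver as a list.
    rows = srdd.takeSample(False, 2, seed=97)

Note that fraction is a per-row inclusion probability rather than an exact output fraction, which is why the doctest pins the seed to get a stable count.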
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 7baf8ffcef..856b10f1a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -438,6 +438,21 @@ class SchemaRDD(
   }
 
   /**
+   * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same
+   * format as javaToPython and collectToPython. It is used by pyspark.
+   */
+  private[sql] def takeSampleToPython(
+      withReplacement: Boolean,
+      num: Int,
+      seed: Long): JList[Array[Byte]] = {
+    val fieldTypes = schema.fields.map(_.dataType)
+    val pickle = new Pickler
+    new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>
+      EvaluatePython.rowToArray(row, fieldTypes)
+    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
+  }
+
+  /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
    *
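As the comment says, takeSampleToPython follows the same wire format as collectToPython: each row is converted to an array of field values, rows are grouped into batches of 100, and each batch is pickled into one byte array, so each element of the returned JList is one pickled batch. On the Python side, _collect_iterator_through_file plus the row class from _create_cls undo this. A standalone sketch of just the unbatching step, not Spark's actual deserializer (unpack_batches and pickled_batches are hypothetical names):

    import pickle

    def unpack_batches(pickled_batches):
        # Each element is one pickled list of up to 100 row arrays;
        # flatten the batches back into a single list of rows.
        rows = []
        for data in pickled_batches:
            rows.extend(pickle.loads(bytes(data)))
        return rows

Batching at 100 rows bounds the size of each pickle call while avoiding a per-row serialization round trip.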
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index ac4844f9b9..5b9c612487 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -218,4 +218,10 @@ class JavaSchemaRDD(
    */
   def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD =
     this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD
+
+  /**
+   * Return a SchemaRDD with a sampled version of the underlying dataset.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD =
+    this.baseSchemaRDD.sample(withReplacement, fraction, seed).toJavaSchemaRDD
 }
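This JVM method is the other end of the Py4J call made by the new Python SchemaRDD.sample above, reached through the _jschema_rdd proxy. A hedged sketch of the round trip, using only names that appear in this patch:

    srdd = sqlCtx.inferSchema(rdd)                # Python-side SchemaRDD
    jrdd = srdd._jschema_rdd                      # Py4J proxy for the JavaSchemaRDD
    jsampled = jrdd.sample(False, 0.5, long(97))  # runs JavaSchemaRDD.sample on the JVM
    sampled = SchemaRDD(jsampled, sqlCtx)         # rewrap as a Python SchemaRDD

This is the same wiring the Python sample performs internally; note the explicit long(...) cast matching the seed: Long parameter of the JVM signature.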