aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author jbencook <jbenjamincook@gmail.com> 2014-12-23 17:46:24 -0800
committer Josh Rosen <joshrosen@databricks.com> 2014-12-23 17:46:24 -0800
commit fd41eb9574280b5cfee9b94b4f92e4c44363fb14 (patch)
tree 1637c2f5ee28a4f0990093c934ac4aa1ced81167 /python
parent 7e2deb71c4239564631b19c748e95c3d1aa1c77d (diff)
download spark-fd41eb9574280b5cfee9b94b4f92e4c44363fb14.tar.gz
spark-fd41eb9574280b5cfee9b94b4f92e4c44363fb14.tar.bz2
spark-fd41eb9574280b5cfee9b94b4f92e4c44363fb14.zip
[SPARK-4860][pyspark][sql] speeding up `sample()` and `takeSample()`
This PR modifies the Python `SchemaRDD` to use `sample()` and `takeSample()` from Scala instead of the slower Python implementations from `rdd.py`. This is worthwhile because the `Row`s are already serialized as Java objects. In order to use the faster `takeSample()`, a `takeSampleToPython()` method was implemented in `SchemaRDD.scala` following the pattern of `collectToPython()`.
Author: jbencook <jbenjamincook@gmail.com>
Author: J. Benjamin Cook <jbenjamincook@gmail.com>
Closes #3764 from jbencook/master and squashes the following commits:
6fbc769 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing sloppy indentation for takeSampleToPython() arguments
5170da2 [J. Benjamin Cook] [SPARK-4860][pyspark][sql] fixing typo: from RDD to SchemaRDD
de22f70 [jbencook] [SPARK-4860][pyspark][sql] using sample() method from JavaSchemaRDD
b916442 [jbencook] [SPARK-4860][pyspark][sql] adding sample() to JavaSchemaRDD
020cbdf [jbencook] [SPARK-4860][pyspark][sql] using Scala implementations of `sample()` and `takeSample()`
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 469f82473a..9807a84a66 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2085,6 +2085,34 @@ class SchemaRDD(RDD):
else:
raise ValueError("Can only subtract another SchemaRDD")
+ def sample(self, withReplacement, fraction, seed=None):
+ """
+ Return a sampled subset of this SchemaRDD.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.sample(False, 0.5, 97).count()
+ 2L
+ """
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
+ return SchemaRDD(rdd, self.sql_ctx)
+
+ def takeSample(self, withReplacement, num, seed=None):
+ """Return a fixed-size sampled subset of this SchemaRDD.
+
+ >>> srdd = sqlCtx.inferSchema(rdd)
+ >>> srdd.takeSample(False, 2, 97)
+ [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ """
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ with SCCallSiteSync(self.context) as css:
+ bytesInJava = self._jschema_rdd.baseSchemaRDD() \
+ .takeSampleToPython(withReplacement, num, long(seed)) \
+ .iterator()
+ cls = _create_cls(self.schema())
+ return map(cls, self._collect_iterator_through_file(bytesInJava))
+
def _test():
import doctest