diff options
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/column.py | 12 | ||||
-rw-r--r-- | python/pyspark/sql/dataframe.py | 4 |
2 files changed, 14 insertions, 2 deletions
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8af8637cf9..0948f9b27c 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -61,6 +61,18 @@ def _to_seq(sc, cols, converter=None): return sc._jvm.PythonUtils.toSeq(cols) +def _to_list(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM (Scala) List of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toList(cols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 025811f519..e269ef4304 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -32,7 +32,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql import since from pyspark.sql.types import _parse_datatype_json_string -from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -494,7 +494,7 @@ class DataFrame(object): if w < 0.0: raise ValueError("Weights must be positive. Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property |