From 57d70d26c88819360cdc806e7124aa2cc1b9e4c5 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 31 Jan 2017 15:42:36 -0800
Subject: [SPARK-17161][PYSPARK][ML] Add PySpark-ML JavaWrapper convenience function to create Py4J JavaArrays

## What changes were proposed in this pull request?

Adding convenience function to Python `JavaWrapper` so that it is easy to create a Py4J JavaArray that is compatible with current class constructors that have a Scala `Array` as input so that it is not necessary to have a Java/Python friendly constructor. The function takes a Java class as input that is used by Py4J to create the Java array of the given class. As an example, `OneVsRest` has been updated to use this and the alternate constructor is removed.

## How was this patch tested?

Added unit tests for the new convenience function and updated `OneVsRest` doctests which use this to persist the model.

Author: Bryan Cutler

Closes #14725 from BryanCutler/pyspark-new_java_array-CountVectorizer-SPARK-17161.
---
 python/pyspark/ml/wrapper.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

(limited to 'python/pyspark/ml/wrapper.py')

diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 13b75e9919..80a0b31cd8 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -16,6 +16,9 @@
 #
 
 from abc import ABCMeta, abstractmethod
+import sys
+if sys.version >= '3':
+    xrange = range
 
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
@@ -59,6 +62,32 @@ class JavaWrapper(object):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+    @staticmethod
+    def _new_java_array(pylist, java_class):
+        """
+        Create a Java array of given java_class type. Useful for
+        calling a method with a Scala Array from Python with Py4J.
+
+        :param pylist:
+          Python list to convert to a Java Array.
+        :param java_class:
+          Java class to specify the type of Array. Should be in the
+          form of sc._gateway.jvm.* (sc is a valid Spark Context).
+        :return:
+          Java Array of converted pylist.
+
+        Example primitive Java classes:
+          - basestring -> sc._gateway.jvm.java.lang.String
+          - int -> sc._gateway.jvm.java.lang.Integer
+          - float -> sc._gateway.jvm.java.lang.Double
+          - bool -> sc._gateway.jvm.java.lang.Boolean
+        """
+        sc = SparkContext._active_spark_context
+        java_array = sc._gateway.new_array(java_class, len(pylist))
+        for i in xrange(len(pylist)):
+            java_array[i] = pylist[i]
+        return java_array
+
 
 @inherit_doc
 class JavaParams(JavaWrapper, Params):
-- 
cgit v1.2.3