 mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala |  84
 python/pyspark/mllib/classification.py                                      |  30
 python/pyspark/mllib/clustering.py                                          |  15
 python/pyspark/mllib/common.py                                              | 135
 python/pyspark/mllib/feature.py                                             | 122
 python/pyspark/mllib/linalg.py                                              |  12
 python/pyspark/mllib/random.py                                              |  34
 python/pyspark/mllib/recommendation.py                                      |  62
 python/pyspark/mllib/regression.py                                          |  52
 python/pyspark/mllib/stat.py                                                |  65
 python/pyspark/mllib/tree.py                                                |  55
 python/pyspark/mllib/util.py                                                |   7
 12 files changed, 287 insertions(+), 386 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 485abe2723..acdc67ddc6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
-import java.util.{ArrayList => JArrayList}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -72,15 +72,11 @@ class PythonMLLibAPI extends Serializable {
private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
- initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
+ initialWeights: Vector): JList[Object] = {
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
learner.disableUncachedWarning()
val model = learner.run(data.rdd, initialWeights)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.dumps(model.weights))
- ret.add(model.intercept: java.lang.Double)
- ret
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}
/**
@@ -91,10 +87,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
@@ -113,7 +109,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lrAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -125,7 +121,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val lassoAlg = new LassoWithSGD()
lassoAlg.optimizer
.setNumIterations(numIterations)
@@ -135,7 +131,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lassoAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -147,7 +143,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
ridgeAlg.optimizer
.setNumIterations(numIterations)
@@ -157,7 +153,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
ridgeAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -169,9 +165,9 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
@@ -190,7 +186,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
SVMAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -201,10 +197,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
@@ -223,7 +219,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
LogRegAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -231,13 +227,10 @@ class PythonMLLibAPI extends Serializable {
*/
def trainNaiveBayes(
data: JavaRDD[LabeledPoint],
- lambda: Double): java.util.List[java.lang.Object] = {
+ lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(Vectors.dense(model.labels))
- ret.add(Vectors.dense(model.pi))
- ret.add(model.theta)
- ret
+ List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
+ map(_.asInstanceOf[Object]).asJava
}
/**
@@ -260,6 +253,21 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * A wrapper of MatrixFactorizationModel to provide helper methods for Python
+ */
+ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel)
+ extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) {
+
+ def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
+ predict(SerDe.asTupleRDD(userAndProducts.rdd))
+
+ def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ }
+
+ /**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
* needs to be taken in the Python code to ensure it gets freed on exit; see
@@ -271,7 +279,7 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
- ALS.train(ratings.rdd, rank, iterations, lambda, blocks)
+ new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks))
}
/**
@@ -287,7 +295,8 @@ class PythonMLLibAPI extends Serializable {
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
- ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
+ new MatrixFactorizationModelWrapper(
+ ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha))
}
/**
@@ -373,19 +382,16 @@ class PythonMLLibAPI extends Serializable {
rdd.rdd.map(model.transform)
}
- def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(word: String, num: Int): JList[Object] = {
val vec = transform(word)
findSynonyms(vec, num)
}
- def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(vector: Vector, num: Int): JList[Object] = {
val result = model.findSynonyms(vector, num)
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(words)
- ret.add(similarity)
- ret
+ List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
}
@@ -395,13 +401,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
* @param data Training data
- * @param categoricalFeaturesInfoJMap Categorical features info, as Java map
+ * @param categoricalFeaturesInfo Categorical features info, as Java map
*/
def trainDecisionTreeModel(
data: JavaRDD[LabeledPoint],
algoStr: String,
numClasses: Int,
- categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+ categoricalFeaturesInfo: JMap[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int,
@@ -417,7 +423,7 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
@@ -589,7 +595,7 @@ private[spark] object SerDe extends Serializable {
if (objects.length == 0 || objects.length > 3) {
out.write(Opcodes.MARK)
}
- objects.foreach(pickler.save(_))
+ objects.foreach(pickler.save)
val code = objects.length match {
case 1 => Opcodes.TUPLE1
case 2 => Opcodes.TUPLE2
@@ -719,7 +725,7 @@ private[spark] object SerDe extends Serializable {
}
/* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
- def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
rdd.map(x => Array(x._1, x._2))
}
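
Note on the new return convention: the stubs above now hand back a JList[Object] built with List(...).asJava, and SerDe pickles the whole list in one pass. A minimal sketch of what that looks like from the Python side (hedged: `points`, an RDD of LabeledPoint, and an active SparkContext are assumed to exist):

    import numpy
    from pyspark.mllib.classification import NaiveBayesModel
    from pyspark.mllib.common import callMLlibFunc

    # trainNaiveBayes builds List(labels, pi, theta) on the JVM side;
    # callMLlibFunc unpickles it into a plain three-element Python list.
    labels, pi, theta = callMLlibFunc("trainNaiveBayes", points, 1.0)
    model = NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
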
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index e295c9d095..297a2bf37d 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -20,8 +20,8 @@ from math import exp
import numpy
from numpy import array
-from pyspark import SparkContext, PickleSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
@@ -102,14 +102,11 @@ class LogisticRegressionWithSGD(object):
training data (i.e. whether bias features
are activated or not).
"""
- sc = data.context
+ def train(rdd, i):
+ return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step,
+ miniBatchFraction, i, regParam, regType, intercept)
- def train(jdata, i):
- return sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
- jdata, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
-
- return _regression_train_wrapper(sc, train, LogisticRegressionModel, data,
- initialWeights)
+ return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
class SVMModel(LinearModel):
@@ -174,13 +171,11 @@ class SVMWithSGD(object):
training data (i.e. whether bias features
are activated or not).
"""
- sc = data.context
-
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
- jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
+ def train(rdd, i):
+ return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam,
+ miniBatchFraction, i, regType, intercept)
- return _regression_train_wrapper(sc, train, SVMModel, data, initialWeights)
+ return _regression_train_wrapper(train, SVMModel, data, initialWeights)
class NaiveBayesModel(object):
@@ -243,14 +238,13 @@ class NaiveBayes(object):
(e.g. a count vector).
:param lambda_: The smoothing parameter
"""
- sc = data.context
- jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_)
- labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist)))
+ labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_)
return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
def _test():
import doctest
+ from pyspark import SparkContext
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
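
The classifier entry points are unchanged for callers; only the plumbing moved into callMLlibFunc. An illustrative sketch with hypothetical toy data and a local context:

    from numpy import array
    from pyspark import SparkContext
    from pyspark.mllib.classification import LogisticRegressionWithSGD
    from pyspark.mllib.regression import LabeledPoint

    sc = SparkContext("local", "classification sketch")
    data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]),
                           LabeledPoint(1.0, [1.0, 0.0])])
    # train() now closes over callMLlibFunc; the wrapper injects the RDD
    # and the initial weights vector.
    model = LogisticRegressionWithSGD.train(data, iterations=10)
    print model.predict(array([1.0, 0.0]))
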
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 5ee7997104..fe4c4cc509 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -16,8 +16,8 @@
#
from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['KMeansModel', 'KMeans']
@@ -80,14 +80,11 @@ class KMeans(object):
@classmethod
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
"""Train a k-means clustering model."""
- sc = rdd.context
- ser = PickleSerializer()
# cache serialized data to avoid objects over head in JVM
- cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache()
- model = sc._jvm.PythonMLLibAPI().trainKMeansModel(
- _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode)
- bytes = sc._jvm.SerDe.dumps(model.clusterCenters())
- centers = ser.loads(str(bytes))
+ jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True)
+ model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs,
+ initializationMode)
+ centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
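
KMeans.train now caches the pickled input via _to_java_object_rdd(..., cache=True) and pulls the centers back through callJavaFunc; the public API is untouched. A minimal sketch, assuming a local context and toy points:

    from numpy import array
    from pyspark import SparkContext
    from pyspark.mllib.clustering import KMeans

    sc = SparkContext("local", "clustering sketch")
    data = sc.parallelize([[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0]])
    model = KMeans.train(data, 2, maxIterations=10, runs=1,
                         initializationMode="k-means||")
    print model.predict(array([0.5, 0.5]))  # index of the nearest center
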
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
new file mode 100644
index 0000000000..76864d8163
--- /dev/null
+++ b/python/pyspark/mllib/common.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+
+
+# Hack for support float('inf') in Py4j
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+ 'nan': 'NaN',
+ 'inf': 'Infinity',
+ '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+ if isinstance(obj, float):
+ s = unicode(obj)
+ return _float_str_mapping.get(s, s)
+ return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
+
+
+_picklable_classes = [
+ 'LinkedList',
+ 'SparseVector',
+ 'DenseVector',
+ 'DenseMatrix',
+ 'Rating',
+ 'LabeledPoint',
+]
+
+
+# this will call the MLlib version of pythonToJava()
+def _to_java_object_rdd(rdd, cache=False):
+ """ Return an JavaRDD of Object by unpickling
+
+ It will convert each Python object into Java object by Pyrolite, whenever the
+ RDD is serialized in batch or not.
+ """
+ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+ if cache:
+ rdd.cache()
+ return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
+
+
+def _py2java(sc, obj):
+ """ Convert Python object into Java """
+ if isinstance(obj, RDD):
+ obj = _to_java_object_rdd(obj)
+ elif isinstance(obj, SparkContext):
+ obj = obj._jsc
+ elif isinstance(obj, dict):
+ obj = MapConverter().convert(obj, sc._gateway._gateway_client)
+ elif isinstance(obj, (list, tuple)):
+ obj = ListConverter().convert(obj, sc._gateway._gateway_client)
+ elif isinstance(obj, JavaObject):
+ pass
+ elif isinstance(obj, (int, long, float, bool, basestring)):
+ pass
+ else:
+ bytes = bytearray(PickleSerializer().dumps(obj))
+ obj = sc._jvm.SerDe.loads(bytes)
+ return obj
+
+
+def _java2py(sc, r):
+ if isinstance(r, JavaObject):
+ clsName = r.getClass().getSimpleName()
+ # convert RDD into JavaRDD
+ if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+ r = r.toJavaRDD()
+ clsName = 'JavaRDD'
+
+ if clsName == 'JavaRDD':
+ jrdd = sc._jvm.SerDe.javaToPython(r)
+ return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer()))
+
+ elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes:
+ r = sc._jvm.SerDe.dumps(r)
+
+ if isinstance(r, bytearray):
+ r = PickleSerializer().loads(str(r))
+ return r
+
+
+def callJavaFunc(sc, func, *args):
+ """ Call Java Function """
+ args = [_py2java(sc, a) for a in args]
+ return _java2py(sc, func(*args))
+
+
+def callMLlibFunc(name, *args):
+ """ Call API in PythonMLLibAPI """
+ sc = SparkContext._active_spark_context
+ api = getattr(sc._jvm.PythonMLLibAPI(), name)
+ return callJavaFunc(sc, api, *args)
+
+
+class JavaModelWrapper(object):
+ """
+ Wrapper for the model in JVM
+ """
+ def __init__(self, java_model):
+ self._sc = SparkContext._active_spark_context
+ self._java_model = java_model
+
+ def __del__(self):
+ self._sc._gateway.detach(self._java_model)
+
+ def call(self, name, *a):
+ """Call method of java_model"""
+ return callJavaFunc(self._sc, getattr(self._java_model, name), *a)
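
A hedged sketch of how the new helpers compose; MyModel and its numNodes method are hypothetical, and a local SparkContext is assumed:

    from pyspark import SparkContext
    from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc

    sc = SparkContext("local", "common sketch")

    # callMLlibFunc instantiates PythonMLLibAPI, resolves the named method,
    # pushes every argument through _py2java (RDDs are reserialized and
    # unpickled into JavaRDD[Object]), then maps the result back via _java2py.
    rdd = callMLlibFunc("uniformRDD", sc._jsc, 10, 2, 1L)
    print rdd.collect()

    # Models keep only a JVM handle; call() routes each method invocation
    # through the same conversion layer.
    class MyModel(JavaModelWrapper):
        def numNodes(self):
            return self.call("numNodes")
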
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 324343443e..44bf6f269d 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -21,89 +21,16 @@ Python package for feature in MLlib.
import sys
import warnings
-import py4j.protocol
from py4j.protocol import Py4JJavaError
-from py4j.java_gateway import JavaObject
from pyspark import RDD, SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import Vectors, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import Vectors
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
-# Hack for support float('inf') in Py4j
-_old_smart_decode = py4j.protocol.smart_decode
-
-_float_str_mapping = {
- u'nan': u'NaN',
- u'inf': u'Infinity',
- u'-inf': u'-Infinity',
-}
-
-
-def _new_smart_decode(obj):
- if isinstance(obj, float):
- s = unicode(obj)
- return _float_str_mapping.get(s, s)
- return _old_smart_decode(obj)
-
-py4j.protocol.smart_decode = _new_smart_decode
-
-
-# TODO: move these helper functions into utils
-_picklable_classes = [
- 'LinkedList',
- 'SparseVector',
- 'DenseVector',
- 'DenseMatrix',
- 'Rating',
- 'LabeledPoint',
-]
-
-
-def _py2java(sc, a):
- """ Convert Python object into Java """
- if isinstance(a, RDD):
- a = _to_java_object_rdd(a)
- elif not isinstance(a, (int, long, float, bool, basestring)):
- bytes = bytearray(PickleSerializer().dumps(a))
- a = sc._jvm.SerDe.loads(bytes)
- return a
-
-
-def _java2py(sc, r):
- if isinstance(r, JavaObject):
- clsName = r.getClass().getSimpleName()
- if clsName in ("RDD", "JavaRDD"):
- if clsName == "RDD":
- r = r.toJavaRDD()
- jrdd = sc._jvm.SerDe.javaToPython(r)
- return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer()))
-
- elif clsName in _picklable_classes:
- r = sc._jvm.SerDe.dumps(r)
-
- if isinstance(r, bytearray):
- r = PickleSerializer().loads(str(r))
- return r
-
-
-def _callJavaFunc(sc, func, *args):
- """ Call Java Function
- """
- args = [_py2java(sc, a) for a in args]
- return _java2py(sc, func(*args))
-
-
-def _callAPI(sc, name, *args):
- """ Call API in PythonMLLibAPI
- """
- api = getattr(sc._jvm.PythonMLLibAPI(), name)
- return _callJavaFunc(sc, api, *args)
-
-
class VectorTransformer(object):
"""
:: DeveloperApi ::
@@ -160,25 +87,19 @@ class Normalizer(VectorTransformer):
"""
sc = SparkContext._active_spark_context
assert sc is not None, "SparkContext should be initialized first"
- return _callAPI(sc, "normalizeVector", self.p, vector)
+ return callMLlibFunc("normalizeVector", self.p, vector)
-class JavaModelWrapper(VectorTransformer):
+class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
"""
Wrapper for the model in JVM
"""
- def __init__(self, sc, java_model):
- self._sc = sc
- self._java_model = java_model
-
- def __del__(self):
- self._sc._gateway.detach(self._java_model)
def transform(self, dataset):
- return _callJavaFunc(self._sc, self._java_model.transform, dataset)
+ return self.call("transform", dataset)
-class StandardScalerModel(JavaModelWrapper):
+class StandardScalerModel(JavaVectorTransformer):
"""
:: Experimental ::
@@ -192,7 +113,7 @@ class StandardScalerModel(JavaModelWrapper):
:return: Standardized vector. If the variance of a column is zero,
it will return default `0.0` for the column with zero variance.
"""
- return JavaModelWrapper.transform(self, vector)
+ return JavaVectorTransformer.transform(self, vector)
class StandardScaler(object):
@@ -233,9 +154,8 @@ class StandardScaler(object):
the transformation model.
:return: a StandardScalarModel
"""
- sc = dataset.context
- jmodel = _callAPI(sc, "fitStandardScaler", self.withMean, self.withStd, dataset)
- return StandardScalerModel(sc, jmodel)
+ jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset)
+ return StandardScalerModel(jmodel)
class HashingTF(object):
@@ -276,7 +196,7 @@ class HashingTF(object):
return Vectors.sparse(self.numFeatures, freq.items())
-class IDFModel(JavaModelWrapper):
+class IDFModel(JavaVectorTransformer):
"""
Represents an IDF model that can transform term frequency vectors.
"""
@@ -291,7 +211,7 @@ class IDFModel(JavaModelWrapper):
:param dataset: an RDD of term frequency vectors
:return: an RDD of TF-IDF vectors
"""
- return JavaModelWrapper.transform(self, dataset)
+ return JavaVectorTransformer.transform(self, dataset)
class IDF(object):
@@ -335,12 +255,11 @@ class IDF(object):
:param dataset: an RDD of term frequency vectors
"""
- sc = dataset.context
- jmodel = _callAPI(sc, "fitIDF", self.minDocFreq, dataset)
- return IDFModel(sc, jmodel)
+ jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset)
+ return IDFModel(jmodel)
-class Word2VecModel(JavaModelWrapper):
+class Word2VecModel(JavaVectorTransformer):
"""
class for Word2Vec model
"""
@@ -354,7 +273,7 @@ class Word2VecModel(JavaModelWrapper):
:return: vector representation of word(s)
"""
try:
- return _callJavaFunc(self._sc, self._java_model.transform, word)
+ return self.call("transform", word)
except Py4JJavaError:
raise ValueError("%s not found" % word)
@@ -368,7 +287,7 @@ class Word2VecModel(JavaModelWrapper):
Note: local use only
"""
- words, similarity = _callJavaFunc(self._sc, self._java_model.findSynonyms, word, num)
+ words, similarity = self.call("findSynonyms", word, num)
return zip(words, similarity)
@@ -458,11 +377,10 @@ class Word2Vec(object):
:param data: training data. RDD of subtype of Iterable[String]
:return: Word2VecModel instance
"""
- sc = data.context
- jmodel = _callAPI(sc, "trainWord2Vec", data, int(self.vectorSize),
- float(self.learningRate), int(self.numPartitions),
- int(self.numIterations), long(self.seed))
- return Word2VecModel(sc, jmodel)
+ jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
+ float(self.learningRate), int(self.numPartitions),
+ int(self.numIterations), long(self.seed))
+ return Word2VecModel(jmodel)
def _test():
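
With the shared machinery gone from this file, the feature models are thin JavaVectorTransformer shims. An illustrative sketch (toy data and a local context assumed):

    from pyspark import SparkContext
    from pyspark.mllib.feature import StandardScaler
    from pyspark.mllib.linalg import Vectors

    sc = SparkContext("local", "feature sketch")
    data = sc.parallelize([Vectors.dense([1.0, 2.0]),
                           Vectors.dense([3.0, 4.0]),
                           Vectors.dense([5.0, 6.0])])
    scaler = StandardScaler(withMean=True, withStd=True)
    model = scaler.fit(data)               # a StandardScalerModel JVM handle
    print model.transform(data).collect()  # self.call("transform", dataset)
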
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 1b9bf59624..d0a0e102a1 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,7 +29,6 @@ import copy_reg
import numpy as np
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
@@ -52,17 +51,6 @@ except:
_have_scipy = False
-# this will call the MLlib version of pythonToJava()
-def _to_java_object_rdd(rdd):
- """ Return an JavaRDD of Object by unpickling
-
- It will convert each Python object into Java object by Pyrolite, whenever the
- RDD is serialized in batch or not.
- """
- rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
- return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
-
-
def _convert_to_vector(l):
if isinstance(l, Vector):
return l
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
index 2202c51ab9..7eebfc6bcd 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/random.py
@@ -21,22 +21,12 @@ Python package for random data generation.
from functools import wraps
-from pyspark.rdd import RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.mllib.common import callMLlibFunc
__all__ = ['RandomRDDs', ]
-def serialize(f):
- @wraps(f)
- def func(sc, *a, **kw):
- jrdd = f(sc, *a, **kw)
- return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc,
- BatchedSerializer(PickleSerializer(), 1024))
- return func
-
-
def toArray(f):
@wraps(f)
def func(sc, *a, **kw):
@@ -52,7 +42,6 @@ class RandomRDDs(object):
"""
@staticmethod
- @serialize
def uniformRDD(sc, size, numPartitions=None, seed=None):
"""
Generates an RDD comprised of i.i.d. samples from the
@@ -74,10 +63,9 @@ class RandomRDDs(object):
>>> parts == sc.defaultParallelism
True
"""
- return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed)
+ return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed)
@staticmethod
- @serialize
def normalRDD(sc, size, numPartitions=None, seed=None):
"""
Generates an RDD comprised of i.i.d. samples from the standard normal
@@ -97,10 +85,9 @@ class RandomRDDs(object):
>>> abs(stats.stdev() - 1.0) < 0.1
True
"""
- return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed)
+ return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed)
@staticmethod
- @serialize
def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
"""
Generates an RDD comprised of i.i.d. samples from the Poisson
@@ -117,11 +104,10 @@ class RandomRDDs(object):
>>> abs(stats.stdev() - sqrt(mean)) < 0.5
True
"""
- return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed)
+ return callMLlibFunc("poissonRDD", sc._jsc, mean, size, numPartitions, seed)
@staticmethod
@toArray
- @serialize
def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
"""
Generates an RDD comprised of vectors containing i.i.d. samples drawn
@@ -136,12 +122,10 @@ class RandomRDDs(object):
>>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions()
4
"""
- return sc._jvm.PythonMLLibAPI() \
- .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed)
+ return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
@staticmethod
@toArray
- @serialize
def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
"""
Generates an RDD comprised of vectors containing i.i.d. samples drawn
@@ -156,12 +140,10 @@ class RandomRDDs(object):
>>> abs(mat.std() - 1.0) < 0.1
True
"""
- return sc._jvm.PythonMLLibAPI() \
- .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed)
+ return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
@staticmethod
@toArray
- @serialize
def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
"""
Generates an RDD comprised of vectors containing i.i.d. samples drawn
@@ -179,8 +161,8 @@ class RandomRDDs(object):
>>> abs(mat.std() - sqrt(mean)) < 0.5
True
"""
- return sc._jvm.PythonMLLibAPI() \
- .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed)
+ return callMLlibFunc("poissonVectorRDD", sc._jsc, mean, numRows, numCols,
+ numPartitions, seed)
def _test():
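
Because callMLlibFunc already converts the returned JavaRDD into a Python RDD, the @serialize decorator could be deleted outright; @toArray still turns vector rows into NumPy arrays. A quick sketch, assuming a local context:

    from pyspark import SparkContext
    from pyspark.mllib.random import RandomRDDs

    sc = SparkContext("local", "random sketch")
    x = RandomRDDs.normalRDD(sc, 100, seed=1L)   # RDD of Python floats
    print x.count()
    m = RandomRDDs.uniformVectorRDD(sc, 10, 4)   # rows arrive as arrays
    print m.first()
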
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 22872dbbe3..6b32af07c9 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -16,9 +16,8 @@
#
from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.rdd import RDD
-from pyspark.mllib.linalg import _to_java_object_rdd
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
__all__ = ['MatrixFactorizationModel', 'ALS']
@@ -36,7 +35,7 @@ class Rating(object):
return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating)
-class MatrixFactorizationModel(object):
+class MatrixFactorizationModel(JavaModelWrapper):
"""A matrix factorisation model trained by regularized alternating
least-squares.
@@ -71,48 +70,21 @@ class MatrixFactorizationModel(object):
>>> len(latents) == 4
True
"""
-
- def __init__(self, sc, java_model):
- self._context = sc
- self._java_model = java_model
-
- def __del__(self):
- self._context._gateway.detach(self._java_model)
-
def predict(self, user, product):
return self._java_model.predict(user, product)
def predictAll(self, user_product):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
first = user_product.first()
- if isinstance(first, list):
- user_product = user_product.map(tuple)
- first = tuple(first)
- assert type(first) is tuple and len(first) == 2, \
- "user_product should be RDD of (user, product)"
- if any(isinstance(x, str) for x in first):
- user_product = user_product.map(lambda (u, p): (int(x), int(p)))
- first = tuple(map(int, first))
- assert all(type(x) is int for x in first), "user and product in user_product shoul be int"
- sc = self._context
- tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd())
- jresult = self._java_model.predict(tuplerdd).toJavaRDD()
- return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
- AutoBatchedSerializer(PickleSerializer()))
+ assert len(first) == 2, "user_product should be RDD of (user, product)"
+ user_product = user_product.map(lambda (u, p): (int(u), int(p)))
+ return self.call("predict", user_product)
def userFeatures(self):
- sc = self._context
- juf = self._java_model.userFeatures()
- juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
- return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
- AutoBatchedSerializer(PickleSerializer()))
+ return self.call("getUserFeatures")
def productFeatures(self):
- sc = self._context
- jpf = self._java_model.productFeatures()
- jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
- return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
- AutoBatchedSerializer(PickleSerializer()))
+ return self.call("getProductFeatures")
class ALS(object):
@@ -126,25 +98,19 @@ class ALS(object):
ratings = ratings.map(lambda x: Rating(*x))
else:
raise ValueError("rating should be RDD of Rating or tuple/list")
- # serialize them by AutoBatchedSerializer before cache to reduce the
- # objects overhead in JVM
- cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache()
- return _to_java_object_rdd(cached)
+ return _to_java_object_rdd(ratings, True)
@classmethod
def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
- sc = ratings.context
- jrating = cls._prepare(ratings)
- mod = sc._jvm.PythonMLLibAPI().trainALSModel(jrating, rank, iterations, lambda_, blocks)
- return MatrixFactorizationModel(sc, mod)
+ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations,
+ lambda_, blocks)
+ return MatrixFactorizationModel(model)
@classmethod
def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
- sc = ratings.context
- jrating = cls._prepare(ratings)
- mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
- jrating, rank, iterations, lambda_, blocks, alpha)
- return MatrixFactorizationModel(sc, mod)
+ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank,
+ iterations, lambda_, blocks, alpha)
+ return MatrixFactorizationModel(model)
def _test():
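
MatrixFactorizationModel is now a JavaModelWrapper, and the JVM-side MatrixFactorizationModelWrapper added in the Scala file above supplies predict/getUserFeatures/getProductFeatures in RDD-friendly form. A sketch of the round trip with hypothetical toy ratings:

    from pyspark import SparkContext
    from pyspark.mllib.recommendation import ALS

    sc = SparkContext("local", "recommendation sketch")
    ratings = sc.parallelize([(1, 1, 5.0), (1, 2, 1.0),
                              (2, 1, 1.0), (2, 2, 5.0)])
    model = ALS.train(ratings, rank=1, iterations=5)
    pairs = sc.parallelize([(1, 2), (2, 1)])
    print model.predictAll(pairs).collect()   # RDD of Rating
    print model.userFeatures().collect()      # (id, features) pairs
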
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 93e17faf5c..43c1a2fc10 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,9 +18,8 @@
import numpy as np
from numpy import array
-from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD']
@@ -124,17 +123,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
# train_func should take two parameters, namely data and initial_weights, and
# return the result of a call to the appropriate JVM stub.
# _regression_train_wrapper is responsible for setup and error checking.
-def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights):
+def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
initial_weights = initial_weights or [0.0] * len(data.first().features)
- ser = PickleSerializer()
- initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights)))
- # use AutoBatchedSerializer before cache to reduce the memory
- # overhead in JVM
- cached = data._reserialize(AutoBatchedSerializer(ser)).cache()
- ans = train_func(_to_java_object_rdd(cached), initial_bytes)
- assert len(ans) == 2, "JVM call result had unexpected length"
- weights = ser.loads(str(ans[0]))
- return modelClass(weights, ans[1])
+ weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
+ _convert_to_vector(initial_weights))
+ return modelClass(weights, intercept)
class LinearRegressionWithSGD(object):
@@ -168,13 +161,12 @@ class LinearRegressionWithSGD(object):
training data (i.e. whether bias features
are activated or not).
"""
- sc = data.context
+ def train(rdd, i):
+ return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step,
+ miniBatchFraction, i, regParam, regType, intercept)
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
- jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
-
- return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights)
+ return _regression_train_wrapper(train, LinearRegressionModel,
+ data, initialWeights)
class LassoModel(LinearRegressionModelBase):
@@ -216,12 +208,10 @@ class LassoWithSGD(object):
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a Lasso regression model on the given data."""
- sc = data.context
-
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(
- jrdd, iterations, step, regParam, miniBatchFraction, i)
- return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights)
+ def train(rdd, i):
+ return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam,
+ miniBatchFraction, i)
+ return _regression_train_wrapper(train, LassoModel, data, initialWeights)
class RidgeRegressionModel(LinearRegressionModelBase):
@@ -263,17 +253,17 @@ class RidgeRegressionWithSGD(object):
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a ridge regression model on the given data."""
- sc = data.context
-
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(
- jrdd, iterations, step, regParam, miniBatchFraction, i)
+ def train(rdd, i):
+ return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam,
+ miniBatchFraction, i)
- return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights)
+ return _regression_train_wrapper(train, RidgeRegressionModel,
+ data, initialWeights)
def _test():
import doctest
+ from pyspark import SparkContext
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
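
_regression_train_wrapper now gets (weights, intercept) back as a plain two-element list, with the initial weights passed as a Vector instead of pickled bytes. Usage is unchanged; an illustrative sketch with toy data:

    from pyspark import SparkContext
    from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

    sc = SparkContext("local", "regression sketch")
    data = sc.parallelize([LabeledPoint(0.0, [0.0]),
                           LabeledPoint(1.0, [1.0]),
                           LabeledPoint(3.0, [3.0])])
    model = LinearRegressionWithSGD.train(data, iterations=50,
                                          initialWeights=[0.0])
    print model.weights, model.intercept
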
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index 84baf12b90..15f0652f83 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -19,66 +19,36 @@
Python package for statistical functions in MLlib.
"""
-from functools import wraps
-
-from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import _convert_to_vector
__all__ = ['MultivariateStatisticalSummary', 'Statistics']
-def serialize(f):
- ser = PickleSerializer()
-
- @wraps(f)
- def func(self):
- jvec = f(self)
- bytes = self._sc._jvm.SerDe.dumps(jvec)
- return ser.loads(str(bytes)).toArray()
-
- return func
-
-
-class MultivariateStatisticalSummary(object):
+class MultivariateStatisticalSummary(JavaModelWrapper):
"""
Trait for multivariate statistical summary of a data matrix.
"""
- def __init__(self, sc, java_summary):
- """
- :param sc: Spark context
- :param java_summary: Handle to Java summary object
- """
- self._sc = sc
- self._java_summary = java_summary
-
- def __del__(self):
- self._sc._gateway.detach(self._java_summary)
-
- @serialize
def mean(self):
- return self._java_summary.mean()
+ return self.call("mean").toArray()
- @serialize
def variance(self):
- return self._java_summary.variance()
+ return self.call("variance").toArray()
def count(self):
- return self._java_summary.count()
+ return self.call("count")
- @serialize
def numNonzeros(self):
- return self._java_summary.numNonzeros()
+ return self.call("numNonzeros").toArray()
- @serialize
def max(self):
- return self._java_summary.max()
+ return self.call("max").toArray()
- @serialize
def min(self):
- return self._java_summary.min()
+ return self.call("min").toArray()
class Statistics(object):
@@ -106,10 +76,8 @@ class Statistics(object):
>>> cStats.min()
array([ 2., 0., 0., -2.])
"""
- sc = rdd.ctx
- jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
- cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
- return MultivariateStatisticalSummary(sc, cStats)
+ cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector))
+ return MultivariateStatisticalSummary(cStats)
@staticmethod
def corr(x, y=None, method=None):
@@ -156,7 +124,6 @@ class Statistics(object):
... except TypeError:
... pass
"""
- sc = x.ctx
# Check inputs to determine whether a single value or a matrix is needed for output.
# Since it's legal for users to use the method name as the second argument, we need to
# check if y is used to specify the method name instead.
@@ -164,15 +131,9 @@ class Statistics(object):
raise TypeError("Use 'method=' to specify method name.")
if not y:
- jx = _to_java_object_rdd(x.map(_convert_to_vector))
- resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
- bytes = sc._jvm.SerDe.dumps(resultMat)
- ser = PickleSerializer()
- return ser.loads(str(bytes)).toArray()
+ return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray()
else:
- jx = _to_java_object_rdd(x.map(float))
- jy = _to_java_object_rdd(y.map(float))
- return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
+ return callMLlibFunc("corr", x.map(float), y.map(float), method)
def _test():
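
Each summary statistic is now a single call(), plus toArray() where the JVM returns a vector. A sketch with toy data and a local context:

    from pyspark import SparkContext
    from pyspark.mllib.stat import Statistics

    sc = SparkContext("local", "stat sketch")
    rdd = sc.parallelize([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    summary = Statistics.colStats(rdd)
    print summary.mean()    # NumPy array via call("mean").toArray()
    print summary.count()   # plain number, no pickle round trip
    x = sc.parallelize([1.0, 2.0, 3.0])
    y = sc.parallelize([2.0, 4.0, 6.0])
    print Statistics.corr(x, y)   # scalar correlation
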
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 64ee79d83e..5d1a3c0962 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -15,36 +15,22 @@
# limitations under the License.
#
-from py4j.java_collections import MapConverter
-
from pyspark import SparkContext, RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
-from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
__all__ = ['DecisionTreeModel', 'DecisionTree']
-class DecisionTreeModel(object):
+class DecisionTreeModel(JavaModelWrapper):
"""
A decision tree model for classification or regression.
EXPERIMENTAL: This is an experimental API.
- It will probably be modified for Spark v1.2.
+    It will probably be modified in the future.
"""
-
- def __init__(self, sc, java_model):
- """
- :param sc: Spark context
- :param java_model: Handle to Java model object
- """
- self._sc = sc
- self._java_model = java_model
-
- def __del__(self):
- self._sc._gateway.detach(self._java_model)
-
def predict(self, x):
"""
Predict the label of one or more examples.
@@ -52,24 +38,11 @@ class DecisionTreeModel(object):
:param x: Data point (feature vector),
or an RDD of data points (feature vectors).
"""
- SerDe = self._sc._jvm.SerDe
- ser = PickleSerializer()
if isinstance(x, RDD):
- # Bulk prediction
- first = x.take(1)
- if not first:
- return self._sc.parallelize([])
- if not isinstance(first[0], Vector):
- x = x.map(_convert_to_vector)
- jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD()
- jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred)
- return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024))
+ return self.call("predict", x.map(_convert_to_vector))
else:
- # Assume x is a single data point.
- bytes = bytearray(ser.dumps(_convert_to_vector(x)))
- vec = self._sc._jvm.SerDe.loads(bytes)
- return self._java_model.predict(vec)
+ return self.call("predict", _convert_to_vector(x))
def numNodes(self):
return self._java_model.numNodes()
@@ -98,19 +71,13 @@ class DecisionTree(object):
"""
@staticmethod
- def _train(data, type, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
- minInfoGain=0.0):
+ def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32,
+ minInstancesPerNode=1, minInfoGain=0.0):
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
- sc = data.context
- jrdd = _to_java_object_rdd(data)
- cfiMap = MapConverter().convert(categoricalFeaturesInfo,
- sc._gateway._gateway_client)
- model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
- jrdd, type, numClasses, cfiMap,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
- return DecisionTreeModel(sc, model)
+ model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ return DecisionTreeModel(model)
@staticmethod
def trainClassifier(data, numClasses, categoricalFeaturesInfo,
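
DecisionTreeModel.predict now reduces to self.call("predict", ...) for both a single vector and an RDD of vectors. A minimal sketch (toy data and a local context assumed):

    from pyspark import SparkContext
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    sc = SparkContext("local", "tree sketch")
    data = sc.parallelize([LabeledPoint(0.0, [0.0]),
                           LabeledPoint(1.0, [1.0]),
                           LabeledPoint(1.0, [2.0])])
    model = DecisionTree.trainClassifier(data, numClasses=2,
                                         categoricalFeaturesInfo={})
    print model.predict([1.5])   # single point, converted to a Vector
    print model.predict(data.map(lambda p: p.features)).collect()
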
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 84b39a4861..96aef8f510 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -18,8 +18,7 @@
import numpy as np
import warnings
-from pyspark.rdd import RDD
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
@@ -173,9 +172,7 @@ class MLUtils(object):
(0.0,[1.01,2.02,3.03])
"""
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
- jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions)
- jpyrdd = sc._jvm.SerDe.javaToPython(jrdd)
- return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer()))
+ return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
def _test():