diff options
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala | 9 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 39 |
2 files changed, 30 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala index ecd3b16598..534edac56b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating} import org.apache.spark.rdd.RDD @@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization predict(SerDe.asTupleRDD(userAndProducts.rdd)) def getUserFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + SerDe.fromTuple2RDD(userFeatures.map { + case (user, feature) => (user, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) } def getProductFeatures: RDD[Array[Any]] = { - SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + SerDe.fromTuple2RDD(productFeatures.map { + case (product, feature) => (product, Vectors.dense(feature)) + }.asInstanceOf[RDD[(Any, Any)]]) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index ab15f0f36a..f976d2f97b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -28,7 +28,6 @@ import scala.reflect.ClassTag import net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -40,15 +39,15 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -279,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)). map(_.asInstanceOf[Object]).asJava } @@ -335,7 +334,7 @@ private[python] class PythonMLLibAPI extends Serializable { mu += model.gaussians(i).mu sigma += model.gaussians(i).sigma } - List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava } finally { data.rdd.unpersist(blocking = false) } @@ -346,20 +345,20 @@ private[python] class PythonMLLibAPI extends Serializable { */ def predictSoftGMM( data: JavaRDD[Vector], - wt: Object, + wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Array[Double]] = { + si: Array[Object]): RDD[Vector] = { - val weight = wt.asInstanceOf[Array[Double]] + val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) } val model = new GaussianMixtureModel(weight, gaussians) - model.predictSoft(data) + model.predictSoft(data).map(Vectors.dense) } - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -936,6 +935,14 @@ private[spark] object SerDe extends Serializable { out.write(code) } + protected def getBytes(obj: Object): Array[Byte] = { + if (obj.getClass.isArray) { + obj.asInstanceOf[Array[Byte]] + } else { + obj.asInstanceOf[String].getBytes(LATIN1) + } + } + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) } @@ -961,7 +968,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(0)) val bb = ByteBuffer.wrap(bytes, 0, bytes.length) bb.order(ByteOrder.nativeOrder()) val db = bb.asDoubleBuffer() @@ -994,7 +1001,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 3) { throw new PickleException("should be 3") } - val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() @@ -1031,8 +1038,8 @@ private[spark] object SerDe extends Serializable { throw new PickleException("should be 3") } val size = args(0).asInstanceOf[Int] - val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) - val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val indiceBytes = getBytes(args(1)) + val valueBytes = getBytes(args(2)) val n = indiceBytes.length / 4 val indices = new Array[Int](n) val values = new Array[Double](n) |