author     Xiangrui Meng <meng@databricks.com>    2015-02-12 10:48:13 -0800
committer  Xiangrui Meng <meng@databricks.com>    2015-02-12 10:48:13 -0800
commit     99bd5006650bb15ec5465ffee1ebaca81354a3df
tree       1029c3456b548c908eb8ff1cb3e012b8b137bb9d /mllib
parent     466b1f671b21f575d28f9c103f51765790914fe3
[SPARK-5757][MLLIB] replace SQL JSON usage in model import/export by json4s
This PR detaches MLlib model import/export code from SQL's JSON support, and hence unblocks #4544. yhuai

Author: Xiangrui Meng <meng@databricks.com>

Closes #4555 from mengxr/SPARK-5757 and squashes the following commits:

b0415e8 [Xiangrui Meng] replace SQL JSON usage by json4s
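For context, here is a minimal standalone sketch of the json4s round trip this commit standardizes on: JsonDSL plus compact(render(...)) on the save side, and parse plus extract on the load side. The class name and field values below are placeholders for illustration, not taken from any saved model.

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object MetadataJsonSketch extends App {
  // Save side: build a JSON object with JsonDSL and render it to a
  // single-line string, the shape the save() methods in this diff write out.
  val metadata: String = compact(render(
    ("class" -> "org.apache.spark.mllib.classification.SVMModel") ~
      ("version" -> "1.0") ~ ("numFeatures" -> 3) ~ ("numClasses" -> 2)))

  // Load side: parse the string back into a JValue and extract typed
  // fields, as the new loadMetadata/getNumFeaturesClasses helpers do.
  implicit val formats = DefaultFormats
  val parsed: JValue = parse(metadata)
  val numFeatures = (parsed \ "numFeatures").extract[Int]
  val numClasses = (parsed \ "numClasses").extract[Int]
  println(s"numFeatures=$numFeatures, numClasses=$numClasses")
}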
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala | 3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 28
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 51
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala | 25
15 files changed, 92 insertions, 127 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 348c1e8760..35a0db76f3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.classification
+import org.json4s.{DefaultFormats, JValue}
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
/**
* :: Experimental ::
@@ -60,16 +60,10 @@ private[mllib] object ClassificationModel {
/**
* Helper method for loading GLM classification model metadata.
- *
- * @param modelClass String name for model class (used for error messages)
* @return (numFeatures, numClasses)
*/
- def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
- metadata.select("numFeatures", "numClasses").take(1)(0) match {
- case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
- case _ => throw new Exception(s"$modelClass unable to load" +
- s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
- }
+ def getNumFeaturesClasses(metadata: JValue): (Int, Int) = {
+ implicit val formats = DefaultFormats
+ ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int])
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 9a391bfff7..420d6e2861 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -173,8 +173,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val (numFeatures, numClasses) =
- ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
// numFeatures, numClasses, weights are checked in model initialization
val model =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d9ce2822dd..f9142bc226 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -18,15 +18,16 @@
package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{SparkContext, SparkException, Logging}
+import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
-
/**
* Model for Naive Bayes Classifiers.
*
@@ -78,7 +79,7 @@ class NaiveBayesModel private[mllib] (
object NaiveBayesModel extends Loader[NaiveBayesModel] {
- import Loader._
+ import org.apache.spark.mllib.util.Loader._
private object SaveLoadV1_0 {
@@ -95,10 +96,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD =
- sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
- .toDataFrame("class", "version", "numFeatures", "numClasses")
- metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
// Create Parquet data.
val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
@@ -126,8 +127,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val (numFeatures, numClasses) =
- ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 24d31e62ba..cfc7f868a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -23,10 +23,9 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
+import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
import org.apache.spark.rdd.RDD
-
/**
* Model for Support Vector Machines (SVMs).
*
@@ -97,8 +96,7 @@ object SVMModel extends Loader[SVMModel] {
val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val (numFeatures, numClasses) =
- ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new SVMModel(data.weights, data.intercept)
assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 8d600572ed..1d118963b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.classification.impl
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext}
/**
* Helper class for import/export of GLM classification models.
@@ -52,16 +55,14 @@ private[classification] object GLMClassificationModel {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD =
- sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
- .toDataFrame("class", "version", "numFeatures", "numClasses")
- metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val data = Data(weights, intercept, threshold)
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
- // TODO: repartition with 1 partition after SPARK-5532 gets fixed
- dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ sc.parallelize(Seq(data), 1).saveAsParquetFile(Loader.dataPath(path))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 16979c9ed4..a3a3b5d418 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,9 @@ import java.lang.{Integer => JavaInteger}
import org.apache.hadoop.fs.Path
import org.jblas.DoubleMatrix
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
@@ -153,7 +156,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
import org.apache.spark.mllib.util.Loader._
override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
- val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+ val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, formatVersion) match {
case (className, "1.0") if className == classNameV1_0 =>
@@ -181,19 +184,20 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
val sc = model.userFeatures.sparkContext
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
- val metadata = (thisClassName, thisFormatVersion, model.rank)
- val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
- metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
}
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+ implicit val formats = DefaultFormats
val sqlContext = new SQLContext(sc)
val (className, formatVersion, metadata) = loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
- val rank = metadata.select("rank").first().getInt(0)
+ val rank = (metadata \ "rank").extract[Int]
val userFeatures = sqlContext.parquetFile(userPath(path))
.map { case Row(id: Int, features: Seq[Double]) =>
(id, features.toArray)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 1159e59fff..e8b0381657 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -58,7 +58,7 @@ object LassoModel extends Loader[LassoModel] {
val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val numFeatures = RegressionModel.getNumFeatures(metadata)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new LassoModel(data.weights, data.intercept)
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 0136dcfdce..6fa7ad52a5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -58,7 +58,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val numFeatures = RegressionModel.getNumFeatures(metadata)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new LinearRegressionModel(data.weights, data.intercept)
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 843e59bdfb..214ac4d0ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.regression
+import org.json4s.{DefaultFormats, JValue}
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
@Experimental
trait RegressionModel extends Serializable {
@@ -55,16 +55,10 @@ private[mllib] object RegressionModel {
/**
* Helper method for loading GLM regression model metadata.
- *
- * @param modelClass String name for model class (used for error messages)
* @return numFeatures
*/
- def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
- metadata.select("numFeatures").take(1)(0) match {
- case Row(nFeatures: Int) => nFeatures
- case _ => throw new Exception(s"$modelClass unable to load" +
- s" numFeatures from metadata: ${Loader.metadataPath(path)}")
- }
+ def getNumFeatures(metadata: JValue): Int = {
+ implicit val formats = DefaultFormats
+ (metadata \ "numFeatures").extract[Int]
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index f2a5f1db1e..8838ca8c14 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -59,7 +59,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val numFeatures = RegressionModel.getNumFeatures(metadata)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new RidgeRegressionModel(data.weights, data.intercept)
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 838100e949..f75de6f637 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -17,6 +17,9 @@
package org.apache.spark.mllib.regression.impl
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
@@ -48,10 +51,10 @@ private[regression] object GLMRegressionModel {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD =
- sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
- .toDataFrame("class", "version", "numFeatures")
- metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> weights.size)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val data = Data(weights, intercept)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b3e8ed9af8..9a586b9d9c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,14 +17,13 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impl._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.SparkContext._
-
/**
* :: Experimental ::
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 89ecf3773d..373192a20c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -19,6 +19,10 @@ package org.apache.spark.mllib.tree.model
import scala.collection.mutable
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
@@ -184,10 +188,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD = sc.parallelize(
- Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
- .toDataFrame("class", "version", "algo", "numNodes")
- metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val nodes = model.topNode.subtreeIterator.toSeq
@@ -269,20 +273,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
}
override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+ implicit val formats = DefaultFormats
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
- val (algo: String, numNodes: Int) = try {
- val algo_numNodes = metadata.select("algo", "numNodes").collect()
- assert(algo_numNodes.length == 1)
- algo_numNodes(0) match {
- case Row(a: String, n: Int) => (a, n)
- }
- } catch {
- // Catch both Error and Exception since the checks above can throw either.
- case e: Throwable =>
- throw new Exception(
- s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
- + s" Error message: ${e.getMessage}")
- }
+ val algo = (metadata \ "algo").extract[String]
+ val numNodes = (metadata \ "numNodes").extract[Int]
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 23bd46baab..dbd69dca60 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -20,18 +20,20 @@ package org.apache.spark.mllib.tree.model
import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Algo
+import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
-import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
-
+import org.apache.spark.sql.SQLContext
/**
* :: Experimental ::
@@ -59,11 +61,11 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
object RandomForestModel extends Loader[RandomForestModel] {
override def load(sc: SparkContext, path: String): RandomForestModel = {
- val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+ val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
assert(metadata.treeWeights.forall(_ == 1.0))
val trees =
TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
@@ -110,11 +112,11 @@ class GradientBoostedTreesModel(
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
- val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+ val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
assert(metadata.combiningStrategy == Sum.toString)
val trees =
TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
@@ -252,7 +254,7 @@ private[tree] object TreeEnsembleModel {
object SaveLoadV1_0 {
- import DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
+ import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
def thisFormatVersion = "1.0"
@@ -276,11 +278,13 @@ private[tree] object TreeEnsembleModel {
import sqlContext.implicits._
// Create JSON metadata.
- val metadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
+ implicit val format = DefaultFormats
+ val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
model.combiningStrategy.toString, model.treeWeights)
- val metadataRDD = sc.parallelize(Seq((className, thisFormatVersion, metadata)), 1)
- .toDataFrame("class", "version", "metadata")
- metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> className) ~ ("version" -> thisFormatVersion) ~
+ ("metadata" -> Extraction.decompose(ensembleMetadata))))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
@@ -290,24 +294,11 @@ private[tree] object TreeEnsembleModel {
}
/**
- * Read metadata from the loaded metadata DataFrame.
- * @param path Path for loading data, used for debug messages.
+ * Read metadata from the loaded JSON metadata.
*/
- def readMetadata(metadata: DataFrame, path: String): Metadata = {
- try {
- // We rely on the try-catch for schema checking rather than creating a schema just for this.
- val metadataArray = metadata.select("metadata.algo", "metadata.treeAlgo",
- "metadata.combiningStrategy", "metadata.treeWeights").collect()
- assert(metadataArray.size == 1)
- Metadata(metadataArray(0).getString(0), metadataArray(0).getString(1),
- metadataArray(0).getString(2), metadataArray(0).getAs[Seq[Double]](3).toArray)
- } catch {
- // Catch both Error and Exception since the checks above can throw either.
- case e: Throwable =>
- throw new Exception(
- s"Unable to load TreeEnsembleModel metadata from: ${Loader.metadataPath(path)}."
- + s" Error message: ${e.getMessage}")
- }
+ def readMetadata(metadata: JValue): Metadata = {
+ implicit val formats = DefaultFormats
+ (metadata \ "metadata").extract[Metadata]
}
/**
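An aside on the Extraction.decompose call in the treeEnsembleModels hunk above: below is a minimal sketch of the decompose/extract round trip that save() and readMetadata() rely on, using an illustrative case class mirroring the ensemble Metadata; the field values are made up.

import org.json4s._
import org.json4s.jackson.JsonMethods._

object EnsembleMetadataSketch extends App {
  // Illustrative stand-in for TreeEnsembleModel's Metadata case class.
  case class Metadata(algo: String, treeAlgo: String,
      combiningStrategy: String, treeWeights: Array[Double])

  implicit val formats = DefaultFormats

  val m = Metadata("Classification", "Classification", "Sum", Array(1.0, 1.0))
  // save() side: decompose the case class into a JValue before rendering.
  val json: JValue = Extraction.decompose(m)
  // readMetadata() side: extract the case class back from the JValue.
  val restored = json.extract[Metadata]
  println(compact(render(json)))
}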
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 56b77a7d12..4458340497 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.util
import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{DataType, StructType, StructField}
-
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
* :: DeveloperApi ::
@@ -120,20 +120,11 @@ private[mllib] object Loader {
* Load metadata from the given path.
* @return (class name, version, metadata)
*/
- def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = {
- val sqlContext = new SQLContext(sc)
- val metadata = sqlContext.jsonFile(metadataPath(path))
- val (clazz, version) = try {
- val metadataArray = metadata.select("class", "version").take(1)
- assert(metadataArray.size == 1)
- metadataArray(0) match {
- case Row(clazz: String, version: String) => (clazz, version)
- }
- } catch {
- case e: Exception =>
- throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}")
- }
+ def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = {
+ implicit val formats = DefaultFormats
+ val metadata = parse(sc.textFile(metadataPath(path)).first())
+ val clazz = (metadata \ "class").extract[String]
+ val version = (metadata \ "version").extract[String]
(clazz, version, metadata)
}
-
}
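Finally, a minimal sketch of how a model loader consumes the (class name, version, JValue) triple that the reworked loadMetadata returns, following the dispatch pattern the loaders in this diff share; the class name here is only an example, not a new API.

import org.json4s._

object LoadDispatchSketch {
  // Hypothetical dispatch mirroring the (className, version) matching above.
  def numFeaturesFor(loadedClassName: String, version: String,
      metadata: JValue): Int = {
    implicit val formats = DefaultFormats
    val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        (metadata \ "numFeatures").extract[Int]
      case _ => throw new Exception(
        s"Unrecognized model: ($loadedClassName, $version)")
    }
  }
}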