diff options
author | Jayant Shekar <jayant@user-MBPMBA-3.local> | 2015-10-23 08:45:13 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-10-23 08:45:13 -0700 |
commit | 4e38defae13b2b13e196b4d172722ef5e6266c66 (patch) | |
tree | 727d246ccd43d9860249a947aee76984c2ab930f /mllib/src/main | |
parent | 282a15f78e08f0dc9e696945be4fc973011a96d9 (diff) | |
download | spark-4e38defae13b2b13e196b4d172722ef5e6266c66.tar.gz spark-4e38defae13b2b13e196b4d172722ef5e6266c66.tar.bz2 spark-4e38defae13b2b13e196b4d172722ef5e6266c66.zip |
[SPARK-6723] [MLLIB] Model import/export for ChiSqSelector
This is a PR for Parquet-based model import/export.
* Added save/load for ChiSqSelectorModel
* Updated the test suite ChiSqSelectorSuite
Author: Jayant Shekar <jayant@user-MBPMBA-3.local>
Closes #6785 from jayantshekhar/SPARK-6723.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 70 |
1 files changed, 69 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index b1524cf377..5246faf221 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature import scala.collection.mutable.ArrayBuilder +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: @@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD @Since("1.3.0") @Experimental class ChiSqSelectorModel @Since("1.3.0") ( - @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { + @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") ( s"Only sparse and dense vectors are supported but got ${other.getClass}.") } } + + @Since("1.6.0") + override def save(sc: SparkContext, path: String): Unit = { + ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { + @Since("1.6.0") + override def load(sc: SparkContext, path: String): ChiSqSelectorModel = { + ChiSqSelectorModel.SaveLoadV1_0.load(sc, path) + } + + private[feature] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + /** Model data for import/export */ + case class Data(feature: Int) + + private[feature] + val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" + + def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Create Parquet data. + val dataArray = Array.tabulate(model.selectedFeatures.length) { i => + Data(model.selectedFeatures(i)) + } + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) + + } + + def load(sc: SparkContext, path: String): ChiSqSelectorModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val dataFrame = sqlContext.read.parquet(Loader.dataPath(path)) + val dataArray = dataFrame.select("feature") + + // Check schema explicitly since erasure makes it hard to use match-case for checking. + Loader.checkSchema[Data](dataFrame.schema) + + val features = dataArray.map { + case Row(feature: Int) => (feature) + }.collect() + + return new ChiSqSelectorModel(features) + } + } } /** |