author     Jayant Shekar <jayant@user-MBPMBA-3.local>    2015-10-23 08:45:13 -0700
committer  Xiangrui Meng <meng@databricks.com>           2015-10-23 08:45:13 -0700
commit     4e38defae13b2b13e196b4d172722ef5e6266c66 (patch)
tree       727d246ccd43d9860249a947aee76984c2ab930f /mllib/src
parent     282a15f78e08f0dc9e696945be4fc973011a96d9 (diff)
[SPARK-6723] [MLLIB] Model import/export for ChiSqSelector
This is a PR for Parquet-based model import/export.

* Added save/load for ChiSqSelectorModel
* Updated the test suite ChiSqSelectorSuite

Author: Jayant Shekar <jayant@user-MBPMBA-3.local>

Closes #6785 from jayantshekhar/SPARK-6723.
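For reference, a minimal sketch of the new save/load round trip, based on the API added in the diff below; the feature indices, the "/tmp/chisq-model" path, and the ambient SparkContext `sc` are illustrative assumptions, not part of this change:

    import org.apache.spark.mllib.feature.ChiSqSelectorModel

    // Build a selector model (indices must be sorted ascending, per the constructor's require).
    val model = new ChiSqSelectorModel(Array(1, 2, 3, 4))

    // Persist it: writes JSON metadata and a Parquet file of selected feature indices under `path`.
    model.save(sc, "/tmp/chisq-model")

    // Read it back and verify the selected features survived the round trip.
    val sameModel = ChiSqSelectorModel.load(sc, "/tmp/chisq-model")
    assert(model.selectedFeatures.sameElements(sameModel.selectedFeatures))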
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala        70
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala   26
2 files changed, 95 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index b1524cf377..5246faf221 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature
import scala.collection.mutable.ArrayBuilder
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
@@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD
@Since("1.3.0")
@Experimental
class ChiSqSelectorModel @Since("1.3.0") (
- @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
+ @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
require(isSorted(selectedFeatures), "Array has to be sorted asc")
@@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") (
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
}
+
+ @Since("1.6.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
+ @Since("1.6.0")
+ override def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+ ChiSqSelectorModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[feature]
+ object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ /** Model data for import/export */
+ case class Data(feature: Int)
+
+ private[feature]
+ val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
+
+ def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
+ Data(model.selectedFeatures(i))
+ }
+ sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
+
+ }
+
+ def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+ implicit val formats = DefaultFormats
+ val sqlContext = new SQLContext(sc)
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
+ val dataArray = dataFrame.select("feature")
+
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val features = dataArray.map {
+ case Row(feature: Int) => (feature)
+ }.collect()
+
+ return new ChiSqSelectorModel(features)
+ }
+ }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 889727fb55..734800a9af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
}.collect().toSet
assert(filteredData == preFilteredData)
}
+
+ test("model load / save") {
+ val model = ChiSqSelectorSuite.createModel()
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model.save(sc, path)
+ val sameModel = ChiSqSelectorModel.load(sc, path)
+ ChiSqSelectorSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+}
+
+object ChiSqSelectorSuite extends SparkFunSuite {
+
+ def createModel(): ChiSqSelectorModel = {
+ val arr = Array(1, 2, 3, 4)
+ new ChiSqSelectorModel(arr)
+ }
+
+ def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = {
+ assert(a.selectedFeatures.deep == b.selectedFeatures.deep)
+ }
}
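For context, the save path above produces a directory with a JSON metadata file (written via saveAsTextFile) and a Parquet data file with one row per selected feature index. Assuming Loader.metadataPath and Loader.dataPath resolve to metadata/ and data/ subdirectories (an assumption; those helpers are not shown in this diff), the on-disk layout looks roughly like:

    /tmp/chisq-model/
      metadata/part-00000   // {"class":"org.apache.spark.mllib.feature.ChiSqSelectorModel","version":"1.0"}
      data/*.parquet        // rows of Data(feature: Int), one per selected feature index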