about summary refs log tree commit diff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala | 70
1 file changed, 69 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index b1524cf377..5246faf221 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature
import scala.collection.mutable.ArrayBuilder
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
@@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD
@Since("1.3.0")
@Experimental
class ChiSqSelectorModel @Since("1.3.0") (
- @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
+ @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
require(isSorted(selectedFeatures), "Array has to be sorted asc")
@@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") (
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
}
+
+ @Since("1.6.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
+ @Since("1.6.0")
+ override def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+ ChiSqSelectorModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[feature]
+ object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ /** Model data for import/export */
+ case class Data(feature: Int)
+
+ private[feature]
+ val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
+
+ def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
+ Data(model.selectedFeatures(i))
+ }
+ sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
+
+ }
+
+ def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+ implicit val formats = DefaultFormats
+ val sqlContext = new SQLContext(sc)
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
+ val dataArray = dataFrame.select("feature")
+
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val features = dataArray.map {
+ case Row(feature: Int) => (feature)
+ }.collect()
+
+ return new ChiSqSelectorModel(features)
+ }
+ }
}
/**