aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2017-02-15 01:15:50 -0800
committerFelix Cheung <felixcheung@apache.org>2017-02-15 01:15:50 -0800
commit3973403d5d90a48e3a995159680239ba5240e30c (patch)
tree3b0afd2edbd4e651d22efa7926fa31ec47516469 /mllib
parent447b2b5309251f3ae37857de73c157e59a0d76df (diff)
downloadspark-3973403d5d90a48e3a995159680239ba5240e30c.tar.gz
spark-3973403d5d90a48e3a995159680239ba5240e30c.tar.bz2
spark-3973403d5d90a48e3a995159680239ba5240e30c.zip
[SPARK-19456][SPARKR] Add LinearSVC R API
## What changes were proposed in this pull request? The Linear SVM classifier was newly added to ML, and a Python API has been added. This JIRA is to add the R-side API. Marked as WIP while the unit tests are being designed. ## How was this patch tested? Please review http://spark.apache.org/contributing.html before opening a pull request. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #16800 from wangmiao1981/svc.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala152
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala2
2 files changed, 154 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
new file mode 100644
index 0000000000..cfd043b66e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{LinearSVC, LinearSVCModel}
+import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.r.RWrapperUtils._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/**
+ * R-facing wrapper around a fitted LinearSVC pipeline.
+ *
+ * @param pipeline fitted PipelineModel; per `fit` below its stages are
+ *                 (RFormulaModel, LinearSVCModel, IndexToString)
+ * @param features feature-column names recovered from the RFormula output schema
+ * @param labels   string labels used to map prediction indices back to label strings
+ */
+private[r] class LinearSVCWrapper private (
+ val pipeline: PipelineModel,
+ val features: Array[String],
+ val labels: Array[String]) extends MLWritable {
+ import LinearSVCWrapper._
+
+ // Stage 1 is the LinearSVCModel because `fit` assembles the pipeline as
+ // Array(rFormulaModel, svc, idxToStr); the cast relies on that ordering.
+ private val svcModel: LinearSVCModel =
+ pipeline.stages(1).asInstanceOf[LinearSVCModel]
+
+ // Model summary values exposed to the R side; lazy so they are only
+ // materialized when actually requested from R.
+ lazy val coefficients: Array[Double] = svcModel.coefficients.toArray
+
+ lazy val intercept: Double = svcModel.intercept
+
+ lazy val numClasses: Int = svcModel.numClasses
+
+ lazy val numFeatures: Int = svcModel.numFeatures
+
+ // Run the full pipeline, then drop internal columns (index-valued
+ // prediction, assembled features, indexed label) so only user-facing
+ // columns — including PREDICTED_LABEL_COL — are returned to R.
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(svcModel.getFeaturesCol)
+ .drop(svcModel.getLabelCol)
+ }
+
+ // Persistence delegates to the companion's writer (metadata + pipeline).
+ override def write: MLWriter = new LinearSVCWrapper.LinearSVCWrapperWriter(this)
+}
+
+/**
+ * Companion providing the R entry point (`fit`) and MLReadable persistence
+ * for [[LinearSVCWrapper]].
+ */
+private[r] object LinearSVCWrapper
+ extends MLReadable[LinearSVCWrapper] {
+
+ // Internal column holding the index-valued prediction; dropped before
+ // results are handed back to R. PREDICTED_LABEL_COL is the user-facing
+ // string-label prediction column.
+ val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
+ val PREDICTED_LABEL_COL = "prediction"
+
+ /**
+ * Fits a LinearSVC model from an R formula and training options.
+ *
+ * Builds an RFormula -> LinearSVC -> IndexToString pipeline so predictions
+ * come back as original label strings rather than indices.
+ *
+ * @param weightCol may be null when the R caller supplies no weight column
+ *                  (presumably how SparkR encodes an unset argument — see
+ *                  the null guard below); all other params map 1:1 onto
+ *                  LinearSVC setters.
+ */
+ def fit(
+ data: DataFrame,
+ formula: String,
+ regParam: Double,
+ maxIter: Int,
+ tol: Double,
+ standardization: Boolean,
+ threshold: Double,
+ weightCol: String,
+ aggregationDepth: Int
+ ): LinearSVCWrapper = {
+
+ // Force label indexing so the string label becomes a numeric column
+ // the classifier can consume.
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ .setForceIndexLabel(true)
+ checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
+
+ // Whether the R formula includes an intercept term drives fitIntercept.
+ val fitIntercept = rFormula.hasIntercept
+
+ // get labels and feature names from output schema
+ val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
+
+ // assemble and fit the pipeline
+ val svc = new LinearSVC()
+ .setRegParam(regParam)
+ .setMaxIter(maxIter)
+ .setTol(tol)
+ .setFitIntercept(fitIntercept)
+ .setStandardization(standardization)
+ .setFeaturesCol(rFormula.getFeaturesCol)
+ .setLabelCol(rFormula.getLabelCol)
+ .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
+ .setThreshold(threshold)
+ .setAggregationDepth(aggregationDepth)
+
+ // Only set the weight column when the caller actually provided one.
+ if (weightCol != null) svc.setWeightCol(weightCol)
+
+ // Map predicted label indices back to the original label strings.
+ val idxToStr = new IndexToString()
+ .setInputCol(PREDICTED_LABEL_INDEX_COL)
+ .setOutputCol(PREDICTED_LABEL_COL)
+ .setLabels(labels)
+
+ // NOTE: the wrapper's svcModel cast assumes this exact stage order
+ // (rFormulaModel at 0, svc at 1, idxToStr at 2).
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, svc, idxToStr))
+ .fit(data)
+
+ new LinearSVCWrapper(pipeline, features, labels)
+ }
+
+ override def read: MLReader[LinearSVCWrapper] = new LinearSVCWrapperReader
+
+ override def load(path: String): LinearSVCWrapper = super.load(path)
+
+ /**
+ * Persists a wrapper as two artifacts under `path`: a one-line JSON
+ * "rMetadata" file (class name, features, labels) and the fitted
+ * pipeline saved via PipelineModel's own writer.
+ */
+ class LinearSVCWrapperWriter(instance: LinearSVCWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ // "class" is consumed by RWrappers.load to dispatch to this wrapper.
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("features" -> instance.features.toSeq) ~
+ ("labels" -> instance.labels.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ // Single partition so the metadata is one text file with one JSON line.
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ /**
+ * Inverse of LinearSVCWrapperWriter: reads the JSON metadata and the
+ * saved pipeline, then reconstructs the wrapper.
+ */
+ class LinearSVCWrapperReader extends MLReader[LinearSVCWrapper] {
+
+ override def load(path: String): LinearSVCWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ // Metadata was written as a single JSON line, so first() suffices.
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val features = (rMetadata \ "features").extract[Array[String]]
+ val labels = (rMetadata \ "labels").extract[Array[String]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new LinearSVCWrapper(pipeline, features, labels)
+ }
+ }
+}
+
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index c44179281b..358e522dfe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -66,6 +66,8 @@ private[r] object RWrappers extends MLReader[Object] {
GBTClassifierWrapper.load(path)
case "org.apache.spark.ml.r.BisectingKMeansWrapper" =>
BisectingKMeansWrapper.load(path)
+ case "org.apache.spark.ml.r.LinearSVCWrapper" =>
+ LinearSVCWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}