about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala 4
-rw-r--r-- python/pyspark/ml/classification.py 131
-rw-r--r-- python/pyspark/ml/param/_shared_params_code_gen.py 2
-rw-r--r-- python/pyspark/ml/param/shared.py 24
4 files changed, 158 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index c4e93bf5e8..3b14c4b004 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -114,7 +114,7 @@ class LinearSVC @Since("2.2.0") (
setDefault(standardization -> true)
/**
- * Sets the value of param [[weightCol]].
+ * Set the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
@@ -421,7 +421,7 @@ private class LinearSVCCostFun(
/**
* LinearSVCAggregator computes the gradient and loss for hinge loss function, as used
- * in binary classification for instances in sparse or dense vector in a online fashion.
+ * in binary classification for instances in sparse or dense vector in an online fashion.
*
* Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 5fe4bab186..f10556ca92 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -31,7 +31,8 @@ from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
-__all__ = ['LogisticRegression', 'LogisticRegressionModel',
+__all__ = ['LinearSVC', 'LinearSVCModel',
+ 'LogisticRegression', 'LogisticRegressionModel',
'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
@@ -60,6 +61,134 @@ class JavaClassificationModel(JavaPredictionModel):
@inherit_doc
+class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+ HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization,
+ HasThreshold, HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
+ """
+ `Linear SVM Classifier <https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM>`_.
+ This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
+
+ >>> from pyspark.sql import Row
+ >>> from pyspark.ml.linalg import Vectors
+ >>> df = sc.parallelize([
+ ... Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
+ ... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
+ >>> svm = LinearSVC(maxIter=5, regParam=0.01)
+ >>> model = svm.fit(df)
+ >>> model.coefficients
+ DenseVector([0.0, -0.2792, -0.1833])
+ >>> model.intercept
+ 1.0206118982229047
+ >>> model.numClasses
+ 2
+ >>> model.numFeatures
+ 3
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()
+ >>> result = model.transform(test0).head()
+ >>> result.prediction
+ 1.0
+ >>> result.rawPrediction
+ DenseVector([-1.4831, 1.4831])
+ >>> svm.setParams("vector")
+ Traceback (most recent call last):
+ ...
+ TypeError: Method setParams forces keyword arguments.
+ >>> svm_path = temp_path + "/svm"
+ >>> svm.save(svm_path)
+ >>> svm2 = LinearSVC.load(svm_path)
+ >>> svm2.getMaxIter()
+ 5
+ >>> model_path = temp_path + "/svm_model"
+ >>> model.save(model_path)
+ >>> model2 = LinearSVCModel.load(model_path)
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+ >>> model.intercept == model2.intercept
+ True
+
+ .. versionadded:: 2.2.0
+ """
+
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
+ fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
+ aggregationDepth=2):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
+ fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
+ aggregationDepth=2):
+ """
+ super(LinearSVC, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.classification.LinearSVC", self.uid)
+ self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True,
+ standardization=True, threshold=0.0, aggregationDepth=2)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.2.0")
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction",
+ fitIntercept=True, standardization=True, threshold=0.0, weightCol=None,
+ aggregationDepth=2):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol="rawPrediction", \
+ fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \
+ aggregationDepth=2):
+ Sets params for Linear SVM Classifier.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return LinearSVCModel(java_model)
+
+
+class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable):
+ """
+ Model fitted by LinearSVC.
+
+ .. versionadded:: 2.2.0
+ """
+
+ @property
+ @since("2.2.0")
+ def coefficients(self):
+ """
+ Model coefficients of Linear SVM Classifier.
+ """
+ return self._call_java("coefficients")
+
+ @property
+ @since("2.2.0")
+ def intercept(self):
+ """
+ Model intercept of Linear SVM Classifier.
+ """
+ return self._call_java("intercept")
+
+ @property
+ @since("2.2.0")
+ def numClasses(self):
+ """
+ Number of classes.
+ """
+ return self._call_java("numClasses")
+
+ @property
+ @since("2.2.0")
+ def numFeatures(self):
+ """
+ Number of features.
+ """
+ return self._call_java("numFeatures")
+
+
+@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 929591236d..51d49b524c 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -143,6 +143,8 @@ if __name__ == "__main__":
"The class with largest value p/t is predicted, where p is the original " +
"probability of that class and t is the class's threshold.", None,
"TypeConverters.toListFloat"),
+ ("threshold", "threshold in binary classification prediction, in range [0, 1]",
+ "0.5", "TypeConverters.toFloat"),
("weightCol", "weight column name. If this is not set or empty, we treat " +
"all instance weights as 1.0.", None, "TypeConverters.toString"),
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index cc596936d8..163a0e2b3a 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -490,6 +490,30 @@ class HasThresholds(Params):
return self.getOrDefault(self.thresholds)
+class HasThreshold(Params):
+ """
+ Mixin for param threshold: threshold in binary classification prediction, in range [0, 1]
+ """
+
+ threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]", typeConverter=TypeConverters.toFloat)
+
+ def __init__(self):
+ super(HasThreshold, self).__init__()
+ self._setDefault(threshold=0.5)
+
+ def setThreshold(self, value):
+ """
+ Sets the value of :py:attr:`threshold`.
+ """
+ return self._set(threshold=value)
+
+ def getThreshold(self):
+ """
+ Gets the value of threshold or its default value.
+ """
+ return self.getOrDefault(self.threshold)
+
+
class HasWeightCol(Params):
"""
Mixin for param weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0.