-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala |   8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala      |  41
-rw-r--r--  project/MimaExcludes.scala                                                      |   3
-rw-r--r--  python/pyspark/ml/classification.py                                             | 218
-rw-r--r--  python/pyspark/ml/regression.py                                                 | 245
-rw-r--r--  python/pyspark/ml/tests.py                                                      |  87
-rw-r--r--  python/pyspark/ml/wrapper.py                                                    |  30
7 files changed, 602 insertions(+), 30 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index aeb94a6600..37182928cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -777,10 +777,10 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Dataframe outputted by the model's `transform` method. */
def predictions: DataFrame
- /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
+ /** Field in "predictions" which gives the calibrated probability of each class as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the true label of each instance. */
+ /** Field in "predictions" which gives the true label of each instance (if available). */
def labelCol: String
/** Field in "predictions" which gives the features of each instance as a vector. */
@@ -794,7 +794,7 @@ sealed trait LogisticRegressionSummary extends Serializable {
*
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance as a vector.
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
@@ -818,7 +818,7 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
*
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance.
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
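A minimal usage sketch of what the corrected doc above describes: `probabilityCol` names a column whose rows are vectors with one calibrated probability per class, not a single per-instance value. This assumes a live `sqlContext` (as in the tests added by this patch); the data and names are illustrative only:

from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.linalg import Vectors

# Toy two-class training data; any DataFrame with label/features works.
df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(-1.0))],
    ["label", "features"])
model = LogisticRegression(maxIter=5).fit(df)
summary = model.summary
# Each row of the probability column is one vector per instance,
# holding the calibrated probability of each class.
summary.predictions.select(summary.probabilityCol).show(truncate=False)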
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 2633c06f40..9619e72a45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
- $(featuresCol),
Array(0D))
return lrModel.setSummary(trainingSummary)
@@ -249,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
} else {
@@ -356,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -421,7 +421,7 @@ class LinearRegressionModel private[ml] (
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
- $(labelCol), summaryModel, Array(0D))
+ $(labelCol), $(featuresCol), summaryModel, Array(0D))
}
/**
@@ -511,7 +511,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
/**
* :: Experimental ::
* Linear regression training results. Currently, the training summary ignores the
- * training coefficients except for the objective trace.
+ * training weights except for the objective trace.
*
* @param predictions predictions outputted by the model's `transform` method.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
@@ -522,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ featuresCol: String,
model: LinearRegressionModel,
diagInvAtWA: Array[Double],
- val featuresCol: String,
val objectiveHistory: Array[Double])
- extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) {
+ extends LinearRegressionSummary(
+ predictions,
+ predictionCol,
+ labelCol,
+ featuresCol,
+ model,
+ diagInvAtWA) {
- /** Number of training iterations until termination */
+ /**
+ * Number of training iterations until termination
+ *
+ * This value is only available when using the "l-bfgs" solver.
+ * @see [[LinearRegression.solver]]
+ */
@Since("1.5.0")
val totalIterations = objectiveHistory.length
@@ -539,6 +550,10 @@ class LinearRegressionTrainingSummary private[regression] (
* Linear regression results evaluated on a dataset.
*
* @param predictions predictions outputted by the model's `transform` method.
+ * @param predictionCol Field in "predictions" which gives the predicted value of the label at
+ * each instance.
+ * @param labelCol Field in "predictions" which gives the true label of each instance.
+ * @param featuresCol Field in "predictions" which gives the features of each instance as a vector.
*/
@Since("1.5.0")
@Experimental
@@ -546,6 +561,7 @@ class LinearRegressionSummary private[regression] (
@transient val predictions: DataFrame,
val predictionCol: String,
val labelCol: String,
+ val featuresCol: String,
val model: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
@@ -639,6 +655,9 @@ class LinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ *
+ * This value is only available when using the "normal" solver.
+ * @see [[LinearRegression.solver]]
*/
lazy val coefficientStandardErrors: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -660,6 +679,9 @@ class LinearRegressionSummary private[regression] (
/**
* T-statistic of estimated coefficients and intercept.
+ *
+ * This value is only available when using the "normal" solver.
+ * @see [[LinearRegression.solver]]
*/
lazy val tValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -677,6 +699,9 @@ class LinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ *
+ * This value is only available when using the "normal" solver.
+ * @see [[LinearRegression.solver]]
*/
lazy val pValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
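A brief sketch of the solver-dependent behavior the new Scaladoc describes, assuming a live `sqlContext`; the toy data mirrors the Python tests added by this patch:

from pyspark.ml.regression import LinearRegression
from pyspark.mllib.linalg import Vectors

df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])

# Statistics backed by diagInvAtWA require the "normal" solver.
normal = LinearRegression(solver="normal", fitIntercept=False).fit(df)
print(normal.summary.coefficientStandardErrors)
print(normal.summary.pValues)

# The objective trace is only populated by the iterative "l-bfgs" solver.
lbfgs = LinearRegression(solver="l-bfgs", fitIntercept=False).fit(df)
print(lbfgs.summary.objectiveHistory)
print(lbfgs.summary.totalIterations)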
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9f245afd50..d916c49a6a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -610,6 +610,9 @@ object MimaExcludes {
// [SPARK-13674][SQL] Add wholestage codegen support to Sample
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
+ ) ++ Seq(
+ // [SPARK-13430][ML] moved featuresCol from LinearRegressionTrainingSummary to LinearRegressionSummary
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 067009559b..be7f9ea9ef 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,15 +19,18 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['LogisticRegression', 'LogisticRegressionModel',
+ 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
+ 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
'GBTClassifier', 'GBTClassificationModel',
'RandomForestClassifier', 'RandomForestClassificationModel',
@@ -233,6 +236,219 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+ Gets summary (e.g. area under ROC, objective history, total
+ iterations) of the model trained on the training set. An exception
+ is thrown if `trainingSummary is None`.
+ """
+ java_blrt_summary = self._call_java("summary")
+ # Note: Once multiclass is added, update this to return correct summary
+ return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_blr_summary = self._call_java("evaluate", dataset)
+ return BinaryLogisticRegressionSummary(java_blr_summary)
+
+
+class LogisticRegressionSummary(JavaCallable):
+ """
+ Abstraction for Logistic Regression Results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+ Dataframe outputted by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def probabilityCol(self):
+ """
+ Field in "predictions" which gives the calibrated probability
+ of each class as a vector.
+ """
+ return self._call_java("probabilityCol")
+
+ @property
+ @since("2.0.0")
+ def labelCol(self):
+ """
+ Field in "predictions" which gives the true label of each
+ instance.
+ """
+ return self._call_java("labelCol")
+
+ @property
+ @since("2.0.0")
+ def featuresCol(self):
+ """
+ Field in "predictions" which gives the features of each instance
+ as a vector.
+ """
+ return self._call_java("featuresCol")
+
+
+@inherit_doc
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
+ """
+ Abstraction for multinomial Logistic Regression Training results.
+ Currently, the training summary ignores the training weights except
+ for the objective trace.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def objectiveHistory(self):
+ """
+ Objective function (scaled loss + regularization) at each
+ iteration.
+ """
+ return self._call_java("objectiveHistory")
+
+ @property
+ @since("2.0.0")
+ def totalIterations(self):
+ """
+ Number of training iterations until termination.
+ """
+ return self._call_java("totalIterations")
+
+
+@inherit_doc
+class BinaryLogisticRegressionSummary(LogisticRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Binary Logistic regression results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def roc(self):
+ """
+ Returns the receiver operating characteristic (ROC) curve,
+ which is a DataFrame having two fields (FPR, TPR) with
+ (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+ Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("roc")
+
+ @property
+ @since("2.0.0")
+ def areaUnderROC(self):
+ """
+ Computes the area under the receiver operating characteristic
+ (ROC) curve.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("areaUnderROC")
+
+ @property
+ @since("2.0.0")
+ def pr(self):
+ """
+ Returns the precision-recall curve, which is a DataFrame
+ containing two fields (recall, precision) with (0.0, 1.0)
+ prepended to it.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("pr")
+
+ @property
+ @since("2.0.0")
+ def fMeasureByThreshold(self):
+ """
+ Returns the F-Measure-by-threshold curve as a DataFrame with
+ two fields (threshold, F-Measure), computed with beta = 1.0.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("fMeasureByThreshold")
+
+ @property
+ @since("2.0.0")
+ def precisionByThreshold(self):
+ """
+ Returns the precision-by-threshold curve as a DataFrame with two
+ fields (threshold, precision). Every probability obtained in
+ transforming the dataset is used as a threshold when calculating
+ the precision.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("precisionByThreshold")
+
+ @property
+ @since("2.0.0")
+ def recallByThreshold(self):
+ """
+ Returns the recall-by-threshold curve as a DataFrame with two
+ fields (threshold, recall). Every probability obtained in
+ transforming the dataset is used as a threshold when calculating
+ the recall.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("recallByThreshold")
+
+
+@inherit_doc
+class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
+ LogisticRegressionTrainingSummary):
+ """
+ .. note:: Experimental
+
+ Binary Logistic regression training results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+ pass
+
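The classes above are exercised end to end in tests.py later in this patch; as a quick orientation, a hedged sketch of the intended call pattern (assuming a live `sqlContext`; data is illustrative):

from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.linalg import Vectors

train = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(-1.0))],
    ["label", "features"])
model = LogisticRegression(maxIter=5, regParam=0.01).fit(train)

# Training-time summary: BinaryLogisticRegressionTrainingSummary.
s = model.summary
print(s.areaUnderROC)
print(s.objectiveHistory)
s.roc.show()  # DataFrame with fields (FPR, TPR)

# Post-hoc evaluation: BinaryLogisticRegressionSummary. A held-out
# test DataFrame would normally be passed here.
print(model.evaluate(train).areaUnderROC)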
class TreeClassifierParams(object):
"""
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index de8a5e4bed..6cd1b4bf3a 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,8 +20,9 @@ import warnings
from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
from pyspark.mllib.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@@ -29,6 +30,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel',
+ 'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
'RandomForestRegressor', 'RandomForestRegressionModel']
@@ -131,7 +133,6 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model weights.
"""
-
warnings.warn("weights is deprecated. Use coefficients instead.")
return self._call_java("weights")
@@ -151,6 +152,246 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+ Gets summary (e.g. residuals, mse, r-squared) of the model
+ trained on the training set. An exception is thrown if
+ `trainingSummary is None`.
+ """
+ java_lrt_summary = self._call_java("summary")
+ return LinearRegressionTrainingSummary(java_lrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_lr_summary = self._call_java("evaluate", dataset)
+ return LinearRegressionSummary(java_lr_summary)
+
+
+class LinearRegressionSummary(JavaCallable):
+ """
+ .. note:: Experimental
+
+ Linear regression results evaluated on a dataset.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+ Dataframe outputted by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def predictionCol(self):
+ """
+ Field in "predictions" which gives the predicted value of
+ the label at each instance.
+ """
+ return self._call_java("predictionCol")
+
+ @property
+ @since("2.0.0")
+ def labelCol(self):
+ """
+ Field in "predictions" which gives the true label of each
+ instance.
+ """
+ return self._call_java("labelCol")
+
+ @property
+ @since("2.0.0")
+ def featuresCol(self):
+ """
+ Field in "predictions" which gives the features of each instance
+ as a vector.
+ """
+ return self._call_java("featuresCol")
+
+ @property
+ @since("2.0.0")
+ def explainedVariance(self):
+ """
+ Returns the explained variance regression score.
+ explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+ Reference: http://en.wikipedia.org/wiki/Explained_variation
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("explainedVariance")
+
+ @property
+ @since("2.0.0")
+ def meanAbsoluteError(self):
+ """
+ Returns the mean absolute error, which is a risk function
+ corresponding to the expected value of the absolute error
+ loss or l1-norm loss.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("meanAbsoluteError")
+
+ @property
+ @since("2.0.0")
+ def meanSquaredError(self):
+ """
+ Returns the mean squared error, which is a risk function
+ corresponding to the expected value of the squared error
+ loss or quadratic loss.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("meanSquaredError")
+
+ @property
+ @since("2.0.0")
+ def rootMeanSquaredError(self):
+ """
+ Returns the root mean squared error, which is defined as the
+ square root of the mean squared error.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("rootMeanSquaredError")
+
+ @property
+ @since("2.0.0")
+ def r2(self):
+ """
+ Returns R^2, the coefficient of determination.
+ Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("r2")
+
+ @property
+ @since("2.0.0")
+ def residuals(self):
+ """
+ Residuals (label - predicted value)
+ """
+ return self._call_java("residuals")
+
+ @property
+ @since("2.0.0")
+ def numInstances(self):
+ """
+ Number of instances in the DataFrame "predictions".
+ """
+ return self._call_java("numInstances")
+
+ @property
+ @since("2.0.0")
+ def devianceResiduals(self):
+ """
+ The weighted residuals, the usual residuals rescaled by the
+ square root of the instance weights.
+ """
+ return self._call_java("devianceResiduals")
+
+ @property
+ @since("2.0.0")
+ def coefficientStandardErrors(self):
+ """
+ Standard error of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("coefficientStandardErrors")
+
+ @property
+ @since("2.0.0")
+ def tValues(self):
+ """
+ T-statistic of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("tValues")
+
+ @property
+ @since("2.0.0")
+ def pValues(self):
+ """
+ Two-sided p-value of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("pValues")
+
+
+@inherit_doc
+class LinearRegressionTrainingSummary(LinearRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Linear regression training results. Currently, the training summary ignores the
+ training weights except for the objective trace.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def objectiveHistory(self):
+ """
+ Objective function (scaled loss + regularization) at each
+ iteration.
+ This value is only available when using the "l-bfgs" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("objectiveHistory")
+
+ @property
+ @since("2.0.0")
+ def totalIterations(self):
+ """
+ Number of training iterations until termination.
+ This value is only available when using the "l-bfgs" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("totalIterations")
+
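As with the classification summaries, a short hedged sketch of the intended usage (assuming a live `sqlContext`; the data mirrors the tests added below):

from pyspark.ml.regression import LinearRegression
from pyspark.mllib.linalg import Vectors

train = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])
model = LinearRegression(solver="normal", fitIntercept=False).fit(train)

if model.hasSummary:
    s = model.summary  # LinearRegressionTrainingSummary
    print(s.rootMeanSquaredError, s.r2, s.numInstances)
    s.residuals.show()
    print(s.coefficientStandardErrors)  # "normal" solver only

# LinearRegressionSummary over any DataFrame with the same schema.
print(model.evaluate(train).meanSquaredError)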
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e3f873e3a7..2dcd5eeb52 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -239,6 +239,17 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
return self._set(**kwargs)
+class HasThrowableProperty(Params):
+
+ def __init__(self):
+ super(HasThrowableProperty, self).__init__()
+ self.p = Param(self, "none", "empty param")
+
+ @property
+ def test_property(self):
+ raise RuntimeError("Test property to raise error when invoked")
+
+
class ParamTests(PySparkTestCase):
def test_copy_new_parent(self):
@@ -749,15 +760,75 @@ class PersistenceTest(PySparkTestCase):
pass
-class HasThrowableProperty(Params):
-
- def __init__(self):
- super(HasThrowableProperty, self).__init__()
- self.p = Param(self, "none", "empty param")
+class TrainingSummaryTest(PySparkTestCase):
- @property
- def test_property(self):
- raise RuntimeError("Test property to raise error when invoked")
+ def test_linear_regression_summary(self):
+ from pyspark.mllib.linalg import Vectors
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
+ fitIntercept=False)
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertAlmostEqual(s.explainedVariance, 0.25, 2)
+ self.assertAlmostEqual(s.meanAbsoluteError, 0.0)
+ self.assertAlmostEqual(s.meanSquaredError, 0.0)
+ self.assertAlmostEqual(s.rootMeanSquaredError, 0.0)
+ self.assertAlmostEqual(s.r2, 1.0, 2)
+ self.assertTrue(isinstance(s.residuals, DataFrame))
+ self.assertEqual(s.numInstances, 2)
+ devResiduals = s.devianceResiduals
+ self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float))
+ coefStdErr = s.coefficientStandardErrors
+ self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+ tValues = s.tValues
+ self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+ pValues = s.pValues
+ self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+ # test that evaluation (with the training dataset) produces a summary with the same values
+ # one check is enough to verify a summary is returned; the Scala version runs the full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance)
+
+ def test_logistic_regression_summary(self):
+ from pyspark.mllib.linalg import Vectors
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.roc, DataFrame))
+ self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
+ self.assertTrue(isinstance(s.pr, DataFrame))
+ self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+ # test that evaluation (with the training dataset) produces a summary with the same values
+ # one check is enough to verify a summary is returned; the Scala version runs the full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
if __name__ == "__main__":
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index ca93bf7d7d..a2cf2296fb 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -213,8 +213,30 @@ class JavaTransformer(Transformer, JavaWrapper):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
+class JavaCallable(object):
+ """
+ Wrapper for a plain object in the JVM, used to make Java calls.
+ Can be used as a mixin with another class that defines a
+ `_java_obj` wrapper.
+ """
+ def __init__(self, java_obj=None, sc=None):
+ super(JavaCallable, self).__init__()
+ self._sc = sc if sc is not None else SparkContext._active_spark_context
+ # if this class is a mixin and _java_obj is already defined then don't initialize
+ if java_obj is not None or not hasattr(self, "_java_obj"):
+ self._java_obj = java_obj
+
+ def __del__(self):
+ if self._java_obj is not None:
+ self._sc._gateway.detach(self._java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ java_args = [_py2java(self._sc, arg) for arg in args]
+ return _java2py(self._sc, m(*java_args))
+
+
@inherit_doc
-class JavaModel(Model, JavaTransformer):
+class JavaModel(Model, JavaCallable, JavaTransformer):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +281,3 @@ class JavaModel(Model, JavaTransformer):
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
return that
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- sc = SparkContext._active_spark_context
- java_args = [_py2java(sc, arg) for arg in args]
- return _java2py(sc, m(*java_args))
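To illustrate the refactoring: `_call_java` now lives on `JavaCallable`, so any thin wrapper around a JVM object can reuse it without being a full `JavaModel`. A hypothetical sketch (`MySummary` and its property are illustrative names only; the summary classes added in this patch follow exactly this pattern):

from pyspark.ml.wrapper import JavaCallable

class MySummary(JavaCallable):
    """Wraps a JVM summary object obtained from a Java model call."""

    @property
    def totalIterations(self):
        # Forwarded to the JVM object; _call_java converts arguments
        # with _py2java and the result with _java2py.
        return self._call_java("totalIterations")

# java_obj would come from an earlier JVM call, e.g.
# model._call_java("summary"); it is detached from the gateway on __del__.
# summary = MySummary(java_obj)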