diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 8 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 41 | ||||
-rw-r--r-- | project/MimaExcludes.scala | 3 | ||||
-rw-r--r-- | python/pyspark/ml/classification.py | 218 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 245 | ||||
-rw-r--r-- | python/pyspark/ml/tests.py | 87 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 30 |
7 files changed, 602 insertions, 30 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index aeb94a6600..37182928cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -777,10 +777,10 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Dataframe outputted by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ + /** Field in "predictions" which gives the calibrated probability of each class as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the true label of each instance. */ + /** Field in "predictions" which gives the true label of each instance (if available). */ def labelCol: String /** Field in "predictions" which gives the features of each instance as a vector. */ @@ -794,7 +794,7 @@ sealed trait LogisticRegressionSummary extends Serializable { * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each instance as a vector. + * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. @@ -818,7 +818,7 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each instance. + * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2633c06f40..9619e72a45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -190,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), summaryModel, model.diagInvAtWA.toArray, - $(featuresCol), Array(0D)) return lrModel.setSummary(trainingSummary) @@ -249,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } else { @@ -356,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -421,7 +421,7 @@ class LinearRegressionModel private[ml] ( // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), summaryModel, Array(0D)) + $(labelCol), $(featuresCol), summaryModel, Array(0D)) } /** @@ -511,7 +511,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { /** * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the - * training coefficients except for the objective trace. + * training weights except for the objective trace. * * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. @@ -522,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + featuresCol: String, model: LinearRegressionModel, diagInvAtWA: Array[Double], - val featuresCol: String, val objectiveHistory: Array[Double]) - extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { + extends LinearRegressionSummary( + predictions, + predictionCol, + labelCol, + featuresCol, + model, + diagInvAtWA) { - /** Number of training iterations until termination */ + /** + * Number of training iterations until termination + * + * This value is only available when using the "l-bfgs" solver. + * @see [[LinearRegression.solver]] + */ @Since("1.5.0") val totalIterations = objectiveHistory.length @@ -539,6 +550,10 @@ class LinearRegressionTrainingSummary private[regression] ( * Linear regression results evaluated on a dataset. * * @param predictions predictions outputted by the model's `transform` method. + * @param predictionCol Field in "predictions" which gives the predicted value of the label at + * each instance. + * @param labelCol Field in "predictions" which gives the true label of each instance. + * @param featuresCol Field in "predictions" which gives the features of each instance as a vector. */ @Since("1.5.0") @Experimental @@ -546,6 +561,7 @@ class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, val predictionCol: String, val labelCol: String, + val featuresCol: String, val model: LinearRegressionModel, private val diagInvAtWA: Array[Double]) extends Serializable { @@ -639,6 +655,9 @@ class LinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. + * + * This value is only available when using the "normal" solver. + * @see [[LinearRegression.solver]] */ lazy val coefficientStandardErrors: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -660,6 +679,9 @@ class LinearRegressionSummary private[regression] ( /** * T-statistic of estimated coefficients and intercept. + * + * This value is only available when using the "normal" solver. + * @see [[LinearRegression.solver]] */ lazy val tValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -677,6 +699,9 @@ class LinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. + * + * This value is only available when using the "normal" solver. + * @see [[LinearRegression.solver]] */ lazy val pValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9f245afd50..d916c49a6a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -610,6 +610,9 @@ object MimaExcludes { // [SPARK-13674][SQL] Add wholestage codegen support to Sample ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") + ) ++ Seq( + // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") ) case v if v.startsWith("1.6") => Seq( diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 067009559b..be7f9ea9ef 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,15 +19,18 @@ import warnings from pyspark import since from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary', + 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', @@ -233,6 +236,219 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_blr_summary = self._call_java("evaluate", dataset) + return BinaryLogisticRegressionSummary(java_blr_summary) + + +class LogisticRegressionSummary(JavaCallable): + """ + Abstraction for Logistic Regression Results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def probabilityCol(self): + """ + Field in "predictions" which gives the calibrated probability + of each class as a vector. + """ + return self._call_java("probabilityCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + +@inherit_doc +class LogisticRegressionTrainingSummary(LogisticRegressionSummary): + """ + Abstraction for multinomial Logistic Regression Training results. + Currently, the training summary ignores the training weights except + for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + """ + return self._call_java("totalIterations") + + +@inherit_doc +class BinaryLogisticRegressionSummary(LogisticRegressionSummary): + """ + .. note:: Experimental + + Binary Logistic regression results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def roc(self): + """ + Returns the receiver operating characteristic (ROC) curve, + which is an Dataframe having two fields (FPR, TPR) with + (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("roc") + + @property + @since("2.0.0") + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("areaUnderROC") + + @property + @since("2.0.0") + def pr(self): + """ + Returns the precision-recall curve, which is an Dataframe + containing two fields recall, precision with (0.0, 1.0) prepended + to it. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("pr") + + @property + @since("2.0.0") + def fMeasureByThreshold(self): + """ + Returns a dataframe with two fields (threshold, F-Measure) curve + with beta = 1.0. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("fMeasureByThreshold") + + @property + @since("2.0.0") + def precisionByThreshold(self): + """ + Returns a dataframe with two fields (threshold, precision) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the precision. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("precisionByThreshold") + + @property + @since("2.0.0") + def recallByThreshold(self): + """ + Returns a dataframe with two fields (threshold, recall) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the recall. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("recallByThreshold") + + +@inherit_doc +class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary, + LogisticRegressionTrainingSummary): + """ + .. note:: Experimental + + Binary Logistic regression training results for a given model. + + .. versionadded:: 2.0.0 + """ + pass + class TreeClassifierParams(object): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de8a5e4bed..6cd1b4bf3a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -20,8 +20,9 @@ import warnings from pyspark import since from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @@ -29,6 +30,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'GBTRegressor', 'GBTRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', + 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', 'RandomForestRegressor', 'RandomForestRegressionModel'] @@ -131,7 +133,6 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model weights. """ - warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") @@ -151,6 +152,246 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_lr_summary = self._call_java("evaluate", dataset) + return LinearRegressionSummary(java_lr_summary) + + +class LinearRegressionSummary(JavaCallable): + """ + .. note:: Experimental + + Linear regression results evaluated on a dataset. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def predictionCol(self): + """ + Field in "predictions" which gives the predicted value of + the label at each instance. + """ + return self._call_java("predictionCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns the explained variance regression score. + explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + Reference: http://en.wikipedia.org/wiki/Explained_variation + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("explainedVariance") + + @property + @since("2.0.0") + def meanAbsoluteError(self): + """ + Returns the mean absolute error, which is a risk function + corresponding to the expected value of the absolute error + loss or l1-norm loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanAbsoluteError") + + @property + @since("2.0.0") + def meanSquaredError(self): + """ + Returns the mean squared error, which is a risk function + corresponding to the expected value of the squared error + loss or quadratic loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanSquaredError") + + @property + @since("2.0.0") + def rootMeanSquaredError(self): + """ + Returns the root mean squared error, which is defined as the + square root of the mean squared error. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("rootMeanSquaredError") + + @property + @since("2.0.0") + def r2(self): + """ + Returns R^2^, the coefficient of determination. + Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("r2") + + @property + @since("2.0.0") + def residuals(self): + """ + Residuals (label - predicted value) + """ + return self._call_java("residuals") + + @property + @since("2.0.0") + def numInstances(self): + """ + Number of instances in DataFrame predictions + """ + return self._call_java("numInstances") + + @property + @since("2.0.0") + def devianceResiduals(self): + """ + The weighted residuals, the usual residuals rescaled by the + square root of the instance weights. + """ + return self._call_java("devianceResiduals") + + @property + @since("2.0.0") + def coefficientStandardErrors(self): + """ + Standard error of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("coefficientStandardErrors") + + @property + @since("2.0.0") + def tValues(self): + """ + T-statistic of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("tValues") + + @property + @since("2.0.0") + def pValues(self): + """ + Two-sided p-value of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("pValues") + + +@inherit_doc +class LinearRegressionTrainingSummary(LinearRegressionSummary): + """ + .. note:: Experimental + + Linear regression training results. Currently, the training summary ignores the + training weights except for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("totalIterations") + @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e3f873e3a7..2dcd5eeb52 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -239,6 +239,17 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): return self._set(**kwargs) +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + class ParamTests(PySparkTestCase): def test_copy_new_parent(self): @@ -749,15 +760,75 @@ class PersistenceTest(PySparkTestCase): pass -class HasThrowableProperty(Params): - - def __init__(self): - super(HasThrowableProperty, self).__init__() - self.p = Param(self, "none", "empty param") +class TrainingSummaryTest(PySparkTestCase): - @property - def test_property(self): - raise RuntimeError("Test property to raise error when invoked") + def test_linear_regression_summary(self): + from pyspark.mllib.linalg import Vectors + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertAlmostEqual(s.explainedVariance, 0.25, 2) + self.assertAlmostEqual(s.meanAbsoluteError, 0.0) + self.assertAlmostEqual(s.meanSquaredError, 0.0) + self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) + self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertTrue(isinstance(s.residuals, DataFrame)) + self.assertEqual(s.numInstances, 2) + devResiduals = s.devianceResiduals + self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + + def test_logistic_regression_summary(self): + from pyspark.mllib.linalg import Vectors + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) if __name__ == "__main__": diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index ca93bf7d7d..a2cf2296fb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -213,8 +213,30 @@ class JavaTransformer(Transformer, JavaWrapper): return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) +class JavaCallable(object): + """ + Wrapper for a plain object in JVM to make Java calls, can be used + as a mixin to another class that defines a _java_obj wrapper + """ + def __init__(self, java_obj=None, sc=None): + super(JavaCallable, self).__init__() + self._sc = sc if sc is not None else SparkContext._active_spark_context + # if this class is a mixin and _java_obj is already defined then don't initialize + if java_obj is not None or not hasattr(self, "_java_obj"): + self._java_obj = java_obj + + def __del__(self): + if self._java_obj is not None: + self._sc._gateway.detach(self._java_obj) + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + java_args = [_py2java(self._sc, arg) for arg in args] + return _java2py(self._sc, m(*java_args)) + + @inherit_doc -class JavaModel(Model, JavaTransformer): +class JavaModel(Model, JavaCallable, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before @@ -259,9 +281,3 @@ class JavaModel(Model, JavaTransformer): that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() return that - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - sc = SparkContext._active_spark_context - java_args = [_py2java(sc, arg) for arg in args] - return _java2py(sc, m(*java_args)) |