aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-04-06 12:07:47 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 12:07:47 -0700
commit9c6556c5f8ab013b36312db4bf02c4c6d965a535 (patch)
treee4200c088c376f26f27de4f3a96c99006dd99b20 /mllib/src/main
parentbb1fa5b2182f384cb711fc2be45b0f1a8c466ed6 (diff)
downloadspark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.gz
spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.tar.bz2
spark-9c6556c5f8ab013b36312db4bf02c4c6d965a535.zip
[SPARK-13430][PYSPARK][ML] Python API for training summaries of linear and logistic regression
## What changes were proposed in this pull request? Adding Python API for training summaries of LogisticRegression and LinearRegression in PySpark ML. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes. Also, manually verified values are expected and match those from Scala directly. Author: Bryan Cutler <cutlerb@gmail.com> Closes #11621 from BryanCutler/pyspark-ml-summary-SPARK-13430.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala41
2 files changed, 37 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index aeb94a6600..37182928cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -777,10 +777,10 @@ sealed trait LogisticRegressionSummary extends Serializable {
/** Dataframe outputted by the model's `transform` method. */
def predictions: DataFrame
- /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
+ /** Field in "predictions" which gives the calibrated probability of each class as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the true label of each instance. */
+ /** Field in "predictions" which gives the true label of each instance (if available). */
def labelCol: String
/** Field in "predictions" which gives the features of each instance as a vector. */
@@ -794,7 +794,7 @@ sealed trait LogisticRegressionSummary extends Serializable {
*
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance as a vector.
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
@@ -818,7 +818,7 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
*
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance.
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 2633c06f40..9619e72a45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -190,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
- $(featuresCol),
Array(0D))
return lrModel.setSummary(trainingSummary)
@@ -249,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
} else {
@@ -356,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -421,7 +421,7 @@ class LinearRegressionModel private[ml] (
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
- $(labelCol), summaryModel, Array(0D))
+ $(labelCol), $(featuresCol), summaryModel, Array(0D))
}
/**
@@ -511,7 +511,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
/**
* :: Experimental ::
* Linear regression training results. Currently, the training summary ignores the
- * training coefficients except for the objective trace.
+ * training weights except for the objective trace.
*
* @param predictions predictions outputted by the model's `transform` method.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
@@ -522,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ featuresCol: String,
model: LinearRegressionModel,
diagInvAtWA: Array[Double],
- val featuresCol: String,
val objectiveHistory: Array[Double])
- extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) {
+ extends LinearRegressionSummary(
+ predictions,
+ predictionCol,
+ labelCol,
+ featuresCol,
+ model,
+ diagInvAtWA) {
- /** Number of training iterations until termination */
+ /**
+ * Number of training iterations until termination
+ *
+ * This value is only available when using the "l-bfgs" solver.
+ * @see [[LinearRegression.solver]]
+ */
@Since("1.5.0")
val totalIterations = objectiveHistory.length
@@ -539,6 +550,10 @@ class LinearRegressionTrainingSummary private[regression] (
* Linear regression results evaluated on a dataset.
*
* @param predictions predictions outputted by the model's `transform` method.
+ * @param predictionCol Field in "predictions" which gives the predicted value of the label at
+ * each instance.
+ * @param labelCol Field in "predictions" which gives the true label of each instance.
+ * @param featuresCol Field in "predictions" which gives the features of each instance as a vector.
*/
@Since("1.5.0")
@Experimental
@@ -546,6 +561,7 @@ class LinearRegressionSummary private[regression] (
@transient val predictions: DataFrame,
val predictionCol: String,
val labelCol: String,
+ val featuresCol: String,
val model: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
@@ -639,6 +655,9 @@ class LinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ *
+ * This value is only available when using the "normal" solver.
+ * @see [[LinearRegression.solver]]
*/
lazy val coefficientStandardErrors: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -660,6 +679,9 @@ class LinearRegressionSummary private[regression] (
/**
* T-statistic of estimated coefficients and intercept.
+ *
+ * This value is only available when using the "normal" solver.
+ * @see [[LinearRegression.solver]]
*/
lazy val tValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -677,6 +699,9 @@ class LinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ *
+ * This value is only available when using the "normal" solver.
+ * @see [[LinearRegression.solver]]
*/
lazy val pValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {