aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-05-19 23:35:20 -0700
committerXiangrui Meng <meng@databricks.com>2016-05-19 23:35:20 -0700
commitc94b34ebbf4c6ce353c899c571beb34e8db98917 (patch)
treea742c44515259359153599ee62f6aa0e6bd58e91 /mllib/src/test
parent5e203505f1a092e5849ebd01d9ff9e4fc6cdc34a (diff)
downloadspark-c94b34ebbf4c6ce353c899c571beb34e8db98917.tar.gz
spark-c94b34ebbf4c6ce353c899c571beb34e8db98917.tar.bz2
spark-c94b34ebbf4c6ce353c899c571beb34e8db98917.zip
[SPARK-15339][ML] ML 2.0 QA: Scala APIs and code audit for regression
## What changes were proposed in this pull request? * ```GeneralizedLinearRegression``` API docs enhancement. * The default value of ```GeneralizedLinearRegression``` ```linkPredictionCol``` is not set rather than empty. This will consistent with other similar params such as ```weightCol``` * Make some methods more private. * Fix a minor bug of LinearRegression. * Fix some other issues. ## How was this patch tested? Existing tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #13129 from yanboliang/spark-15339.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala15
1 files changed, 13 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 332d331a47..265f2f45c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -610,20 +610,31 @@ class LinearRegressionSuite
val model1 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
+ .setPredictionCol("myPrediction")
.setSolver(solver)
.fit(datasetWithWeightConstantLabel)
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
model1.coefficients(1))
assert(actual1 ~== expected(idx) absTol 1e-4)
+ // Schema of summary.predictions should be a superset of the input dataset
+ assert((datasetWithWeightConstantLabel.schema.fieldNames.toSet + model1.getPredictionCol)
+ .subsetOf(model1.summary.predictions.schema.fieldNames.toSet))
+
val model2 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
+ .setPredictionCol("myPrediction")
.setSolver(solver)
.fit(datasetWithWeightZeroLabel)
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
model2.coefficients(1))
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
+
+ // Schema of summary.predictions should be a superset of the input dataset
+ assert((datasetWithWeightZeroLabel.schema.fieldNames.toSet + model2.getPredictionCol)
+ .subsetOf(model2.summary.predictions.schema.fieldNames.toSet))
+
idx += 1
}
}
@@ -672,7 +683,7 @@ class LinearRegressionSuite
test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
- val trainer = new LinearRegression().setSolver(solver)
+ val trainer = new LinearRegression().setSolver(solver).setPredictionCol("myPrediction")
val model = trainer.fit(datasetWithDenseFeature)
val trainerNoPredictionCol = trainer.setPredictionCol("")
val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature)
@@ -682,7 +693,7 @@ class LinearRegressionSuite
assert(modelNoPredictionCol.hasSummary)
// Schema should be a superset of the input dataset
- assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf(
+ assert((datasetWithDenseFeature.schema.fieldNames.toSet + model.getPredictionCol).subsetOf(
model.summary.predictions.schema.fieldNames.toSet))
// Validate that we re-insert a prediction column for evaluation
val modelNoPredictionColFieldNames