aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala19
1 files changed, 14 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index bd45d21e8d..eb19d13093 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -61,9 +61,9 @@ class LinearRegressionSuite
val featureSize = 4100
datasetWithSparseFeature = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
- intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray,
- xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
- xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200,
+ intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray,
+ xMean = Seq.fill(featureSize)(r.nextDouble()).toArray,
+ xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200,
seed, eps = 0.1, sparsity = 0.7), 2))
/*
@@ -687,7 +687,7 @@ class LinearRegressionSuite
// Validate that we re-insert a prediction column for evaluation
val modelNoPredictionColFieldNames
= modelNoPredictionCol.summary.predictions.schema.fieldNames
- assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf(
+ assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf(
modelNoPredictionColFieldNames.toSet))
assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
@@ -759,7 +759,7 @@ class LinearRegressionSuite
.sliding(2)
.forall(x => x(0) >= x(1)))
} else {
- // To clalify that the normal solver is used here.
+ // To clarify that the normal solver is used here.
assert(model.summary.objectiveHistory.length == 1)
assert(model.summary.objectiveHistory(0) == 0.0)
val devianceResidualsR = Array(-0.47082, 0.34635)
@@ -1006,6 +1006,15 @@ class LinearRegressionSuite
testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val lr = new LinearRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
+ lr, isClassification = false, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+ }
}
object LinearRegressionSuite {