diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bd45d21e8d..eb19d13093 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -61,9 +61,9 @@ class LinearRegressionSuite val featureSize = 4100 datasetWithSparseFeature = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray, - xMean = Seq.fill(featureSize)(r.nextDouble).toArray, - xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200, + intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, + xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, + xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, seed, eps = 0.1, sparsity = 0.7), 2)) /* @@ -687,7 +687,7 @@ class LinearRegressionSuite // Validate that we re-insert a prediction column for evaluation val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames - assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf( + assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf( modelNoPredictionColFieldNames.toSet)) assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) @@ -759,7 +759,7 @@ class LinearRegressionSuite .sliding(2) .forall(x => x(0) >= x(1))) } else { - // To clalify that the normal solver is used here. + // To clarify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) val devianceResidualsR = Array(-0.47082, 0.34635) @@ -1006,6 +1006,15 @@ class LinearRegressionSuite testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val lr = new LinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } } object LinearRegressionSuite { |