aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala5
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala8
12 files changed, 62 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 2718dd93dc..f8a606d60b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
- val copied = new BisectingKMeansModel(uid, parentModel)
- copyValues(copied, extra)
+ val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(this.parent)
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 8fac63fefb..a0bd66e731 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
- val copied = new GaussianMixtureModel(uid, weights, gaussians)
- copyValues(copied, extra).setParent(this.parent)
+ val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(this.parent)
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 85bb8c93b3..a0d481b294 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -108,8 +108,9 @@ class KMeansModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
- val copied = new KMeansModel(uid, parentModel)
- copyValues(copied, extra)
+ val copied = copyValues(new KMeansModel(uid, parentModel), extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(this.parent)
}
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 8656ecf609..1938e8ecc5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
- copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
- .setParent(parent)
+ val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
+ extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(parent)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 0fdba1cb88..5d1a39f7c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] (
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
validationMetrics.clone())
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
@Since("2.0.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8771fd2e9d..2877285eb4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -141,6 +141,12 @@ class LogisticRegressionSuite
assert(model.getProbabilityCol === "probability")
assert(model.intercept !== 0.0)
assert(model.hasParent)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+ assert(model.hasSummary)
+ val copiedModel = model.copy(ParamMap.empty)
+ assert(copiedModel.hasSummary)
}
test("empty probabilityCol") {
@@ -251,9 +257,6 @@ class LogisticRegressionSuite
mlr.setFitIntercept(false)
val mlrModel = mlr.fit(smallMultinomialDataset)
assert(mlrModel.interceptVector === Vectors.sparse(3, Seq()))
-
- // copied model must have the same parent.
- MLTestingUtils.checkCopy(model)
}
test("logistic regression with setters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index f2368a9f8d..49797d938d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -41,6 +42,13 @@ class BisectingKMeansSuite
assert(bkm.getPredictionCol === "prediction")
assert(bkm.getMaxIter === 20)
assert(bkm.getMinDivisibleClusterSize === 1.0)
+ val model = bkm.setMaxIter(1).fit(dataset)
+
+ // copied model must have the same parent
+ MLTestingUtils.checkCopy(model)
+ assert(model.hasSummary)
+ val copiedModel = model.copy(ParamMap.empty)
+ assert(copiedModel.hasSummary)
}
test("setter/getter") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 003fa6abf6..7165b63ed3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
@@ -43,6 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(gm.getPredictionCol === "prediction")
assert(gm.getMaxIter === 100)
assert(gm.getTol === 0.01)
+ val model = gm.setMaxIter(1).fit(dataset)
+
+ // copied model must have the same parent
+ MLTestingUtils.checkCopy(model)
+ assert(model.hasSummary)
+ val copiedModel = model.copy(ParamMap.empty)
+ assert(copiedModel.hasSummary)
}
test("set parameters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index ca39265355..73972557d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
@@ -47,6 +48,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
assert(kmeans.getInitSteps === 2)
assert(kmeans.getTol === 1e-4)
+ val model = kmeans.setMaxIter(1).fit(dataset)
+
+ // copied model must have the same parent
+ MLTestingUtils.checkCopy(model)
+ assert(model.hasSummary)
+ val copiedModel = model.copy(ParamMap.empty)
+ assert(copiedModel.hasSummary)
}
test("set parameters") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ac1ef5feb9..111bc97464 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random._
@@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
+ assert(model.hasSummary)
+ val copiedModel = model.copy(ParamMap.empty)
+ assert(copiedModel.hasSummary)
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index c0e8afbf5e..df97d0b2ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
@@ -143,6 +143,9 @@ class LinearRegressionSuite
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
+ assert(model.hasSummary)
+ val copiedModel = model.copy(ParamMap.empty)
+ assert(copiedModel.hasSummary)
model.transform(datasetWithDenseFeature)
.select("label", "prediction")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 87100ae2e3..4463a9b6e5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
@@ -78,6 +78,10 @@ class TrainValidationSplitSuite
.setTrainRatio(0.5)
.setSeed(42L)
val cvModel = cv.fit(dataset)
+
+ // copied model must have the same paren.
+ MLTestingUtils.checkCopy(cvModel)
+
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)