diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index a2f009310f..0ddf097a6e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } + test("Model copy and uid checks") { + val mh = new MinHashLSH() + .setInputCol("keys") + .setOutputCol("values") + val model = mh.fit(dataset) + assert(mh.uid === model.uid) + MLTestingUtils.checkCopy(model) + } + test("hashFunction") { val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0))) val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))) |