aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYunni <Euler57721@gmail.com>2016-10-28 14:57:52 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-10-28 14:57:52 -0700
commitac26e9cf27862fbfb97ae18d591606ecf2cd41cf (patch)
tree368b339964686b66e511010c700b2dcb3b467d47 /mllib
parente9746f87d0b553b8115948acb79f7e32c23dfd86 (diff)
downloadspark-ac26e9cf27862fbfb97ae18d591606ecf2cd41cf.tar.gz
spark-ac26e9cf27862fbfb97ae18d591606ecf2cd41cf.tar.bz2
spark-ac26e9cf27862fbfb97ae18d591606ecf2cd41cf.zip
[SPARK-5992][ML] Locality Sensitive Hashing
## What changes were proposed in this pull request? Implement Locality Sensitive Hashing along with approximate nearest neighbors and approximate similarity join based on the [design doc](https://docs.google.com/document/d/1D15DTDMF_UWTTyWqXfG7y76iZalky4QmifUYQ6lH5GM/edit). Detailed changes are as follows: (1) Implement abstract LSH, LSHModel classes as Estimator-Model (2) Implement approxNearestNeighbors and approxSimilarityJoin in the abstract LSHModel (3) Implement Random Projection as LSH subclass for Euclidean distance, Min Hash for Jaccard Distance (4) Implement unit test utility methods including checkLshProperty, checkNearestNeighbor and checkSimilarityJoin Things that will be implemented in a follow-up PR: - Bit Sampling for Hamming Distance, SignRandomProjection for Cosine Distance - PySpark Integration for the scala classes and methods. ## How was this patch tested? Unit test is implemented for all the implemented classes and algorithms. A scalability test on Uber's dataset was performed internally. Tested the methods on [WEX dataset](https://aws.amazon.com/items/2345) from AWS, with the steps and results [here](https://docs.google.com/document/d/19BXg-67U83NVB3M0I84HVBVg3baAVaESD_mrg_-vLro/edit). ## References Gionis, Aristides, Piotr Indyk, and Rajeev Motwani. "Similarity search in high dimensions via hashing." VLDB 7 Sep. 1999: 518-529. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint arXiv:1408.2927 (2014). Author: Yunni <Euler57721@gmail.com> Author: Yun Ni <yunn@uber.com> Closes #15148 from Yunni/SPARK-5992-yunn-lsh.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala313
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala194
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala225
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala153
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala126
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala197
6 files changed, 1208 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
new file mode 100644
index 0000000000..333a8c364a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.Random
+
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.{IntParam, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/**
+ * Params for [[LSH]].
+ */
+private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
+ /**
+ * Param for the dimension of LSH OR-amplification.
+ *
+ * In this implementation, we use LSH OR-amplification to reduce the false negative rate. The
+ * higher the dimension is, the lower the false negative rate.
+ * @group param
+ */
+ final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" +
+ "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" +
+ " improves the running performance", ParamValidators.gt(0))
+
+ /** @group getParam */
+ final def getOutputDim: Int = $(outputDim)
+
+ setDefault(outputDim -> 1)
+
+ /**
+ * Transform the Schema for LSH
+ * @param schema The schema of the input dataset without [[outputCol]]
+ * @return A derived schema with [[outputCol]] added
+ */
+ protected[this] final def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+ }
+}
+
+/**
+ * Model produced by [[LSH]].
+ */
+private[ml] abstract class LSHModel[T <: LSHModel[T]]
+ extends Model[T] with LSHParams with MLWritable {
+ self: T =>
+
+ /**
+ * The hash function of LSH, mapping a predefined KeyType to a Vector
+ * @return The mapping of LSH function.
+ */
+ protected[ml] val hashFunction: Vector => Vector
+
+ /**
+ * Calculate the distance between two different keys using the distance metric corresponding
+ * to the hashFunction
+ * @param x One input vector in the metric space
+ * @param y One input vector in the metric space
+ * @return The distance between x and y
+ */
+ protected[ml] def keyDistance(x: Vector, y: Vector): Double
+
+ /**
+ * Calculate the distance between two different hash Vectors.
+ *
+ * @param x One of the hash vector
+ * @param y Another hash vector
+ * @return The distance between hash vectors x and y
+ */
+ protected[ml] def hashDistance(x: Vector, y: Vector): Double
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ val transformUDF = udf(hashFunction, new VectorUDT)
+ dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ /**
+ * Given a large dataset and an item, approximately find at most k items which have the closest
+ * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
+ * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
+ * transformed data when necessary.
+ *
+ * This method implements two ways of fetching k nearest neighbors:
+ * - Single Probing: Fast, return at most k elements (Probing only one buckets)
+ * - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
+ *
+ * @param dataset the dataset to search for nearest neighbors of the key
+ * @param key Feature vector representing the item to search for
+ * @param numNearestNeighbors The maximum number of nearest neighbors
+ * @param singleProbing True for using Single Probing; false for multiple probing
+ * @param distCol Output column for storing the distance between each result row and the key
+ * @return A dataset containing at most k items closest to the key. A distCol is added to show
+ * the distance between each row and the key.
+ */
+ def approxNearestNeighbors(
+ dataset: Dataset[_],
+ key: Vector,
+ numNearestNeighbors: Int,
+ singleProbing: Boolean,
+ distCol: String): Dataset[_] = {
+ require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
+ // Get Hash Value of the key
+ val keyHash = hashFunction(key)
+ val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
+ transform(dataset)
+ } else {
+ dataset.toDF()
+ }
+
+ // In the origin dataset, find the hash value that is closest to the key
+ val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType)
+ val hashDistCol = hashDistUDF(col($(outputCol)))
+
+ val modelSubset = if (singleProbing) {
+ modelDataset.filter(hashDistCol === 0.0)
+ } else {
+ // Compute threshold to get exact k elements.
+ val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
+ val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
+ val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
+
+ // Filter the dataset where the hash value is less than the threshold.
+ modelDataset.filter(hashDistCol <= hashThreshold)
+ }
+
+ // Get the top k nearest neighbor by their distance to the key
+ val keyDistUDF = udf((x: Vector) => keyDistance(x, key), DataTypes.DoubleType)
+ val modelSubsetWithDistCol = modelSubset.withColumn(distCol, keyDistUDF(col($(inputCol))))
+ modelSubsetWithDistCol.sort(distCol).limit(numNearestNeighbors)
+ }
+
+ /**
+ * Overloaded method for approxNearestNeighbors. Use Single Probing as default way to search
+ * nearest neighbors and "distCol" as default distCol.
+ */
+ def approxNearestNeighbors(
+ dataset: Dataset[_],
+ key: Vector,
+ numNearestNeighbors: Int): Dataset[_] = {
+ approxNearestNeighbors(dataset, key, numNearestNeighbors, true, "distCol")
+ }
+
+ /**
+ * Preprocess step for approximate similarity join. Transform and explode the [[outputCol]] to
+ * two explodeCols: entry and value. "entry" is the index in hash vector, and "value" is the
+ * value of corresponding value of the index in the vector.
+ *
+ * @param dataset The dataset to transform and explode.
+ * @param explodeCols The alias for the exploded columns, must be a seq of two strings.
+ * @return A dataset containing idCol, inputCol and explodeCols
+ */
+ private[this] def processDataset(
+ dataset: Dataset[_],
+ inputName: String,
+ explodeCols: Seq[String]): Dataset[_] = {
+ require(explodeCols.size == 2, "explodeCols must be two strings.")
+ val vectorToMap = udf((x: Vector) => x.asBreeze.iterator.toMap,
+ MapType(DataTypes.IntegerType, DataTypes.DoubleType))
+ val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
+ transform(dataset)
+ } else {
+ dataset.toDF()
+ }
+ modelDataset.select(
+ struct(col("*")).as(inputName),
+ explode(vectorToMap(col($(outputCol)))).as(explodeCols))
+ }
+
+ /**
+ * Recreate a column using the same column name but different attribute id. Used in approximate
+ * similarity join.
+ * @param dataset The dataset where a column need to recreate
+ * @param colName The name of the column to recreate
+ * @param tmpColName A temporary column name which does not conflict with existing columns
+ * @return
+ */
+ private[this] def recreateCol(
+ dataset: Dataset[_],
+ colName: String,
+ tmpColName: String): Dataset[_] = {
+ dataset
+ .withColumnRenamed(colName, tmpColName)
+ .withColumn(colName, col(tmpColName))
+ .drop(tmpColName)
+ }
+
+ /**
+ * Join two dataset to approximately find all pairs of rows whose distance are smaller than
+ * the threshold. If the [[outputCol]] is missing, the method will transform the data; if the
+ * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
+ * data when necessary.
+ *
+ * @param datasetA One of the datasets to join
+ * @param datasetB Another dataset to join
+ * @param threshold The threshold for the distance of row pairs
+ * @param distCol Output column for storing the distance between each result row and the key
+ * @return A joined dataset containing pairs of rows. The original rows are in columns
+ * "datasetA" and "datasetB", and a distCol is added to show the distance of each pair
+ */
+ def approxSimilarityJoin(
+ datasetA: Dataset[_],
+ datasetB: Dataset[_],
+ threshold: Double,
+ distCol: String): Dataset[_] = {
+
+ val leftColName = "datasetA"
+ val rightColName = "datasetB"
+ val explodeCols = Seq("entry", "hashValue")
+ val explodedA = processDataset(datasetA, leftColName, explodeCols)
+
+ // If this is a self join, we need to recreate the inputCol of datasetB to avoid ambiguity.
+ // TODO: Remove recreateCol logic once SPARK-17154 is resolved.
+ val explodedB = if (datasetA != datasetB) {
+ processDataset(datasetB, rightColName, explodeCols)
+ } else {
+ val recreatedB = recreateCol(datasetB, $(inputCol), s"${$(inputCol)}#${Random.nextString(5)}")
+ processDataset(recreatedB, rightColName, explodeCols)
+ }
+
+ // Do a hash join on where the exploded hash values are equal.
+ val joinedDataset = explodedA.join(explodedB, explodeCols)
+ .drop(explodeCols: _*).distinct()
+
+ // Add a new column to store the distance of the two rows.
+ val distUDF = udf((x: Vector, y: Vector) => keyDistance(x, y), DataTypes.DoubleType)
+ val joinedDatasetWithDist = joinedDataset.select(col("*"),
+ distUDF(col(s"$leftColName.${$(inputCol)}"), col(s"$rightColName.${$(inputCol)}")).as(distCol)
+ )
+
+ // Filter the joined datasets where the distance are smaller than the threshold.
+ joinedDatasetWithDist.filter(col(distCol) < threshold)
+ }
+
+ /**
+ * Overloaded method for approxSimilarityJoin. Use "distCol" as default distCol.
+ */
+ def approxSimilarityJoin(
+ datasetA: Dataset[_],
+ datasetB: Dataset[_],
+ threshold: Double): Dataset[_] = {
+ approxSimilarityJoin(datasetA, datasetB, threshold, "distCol")
+ }
+}
+
+/**
+ * Locality Sensitive Hashing for different metrics space. Support basic transformation with a new
+ * hash column, approximate nearest neighbor search with a dataset and a key, and approximate
+ * similarity join of two datasets.
+ *
+ * This LSH class implements OR-amplification: more than 1 hash functions can be chosen, and each
+ * input vector are hashed by all hash functions. Two input vectors are defined to be in the same
+ * bucket as long as ANY one of the hash value matches.
+ *
+ * References:
+ * (1) Gionis, Aristides, Piotr Indyk, and Rajeev Motwani. "Similarity search in high dimensions
+ * via hashing." VLDB 7 Sep. 1999: 518-529.
+ * (2) Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
+ * arXiv:1408.2927 (2014).
+ */
+private[ml] abstract class LSH[T <: LSHModel[T]]
+ extends Estimator[T] with LSHParams with DefaultParamsWritable {
+ self: Estimator[T] =>
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ def setOutputDim(value: Int): this.type = set(outputDim, value)
+
+ /**
+ * Validate and create a new instance of concrete LSHModel. Because different LSHModel may have
+ * different initial setting, developer needs to define how their LSHModel is created instead of
+ * using reflection in this abstract class.
+ * @param inputDim The dimension of the input dataset
+ * @return A new LSHModel instance without any params
+ */
+ protected[this] def createRawLSHModel(inputDim: Int): T
+
+ override def fit(dataset: Dataset[_]): T = {
+ transformSchema(dataset.schema, logging = true)
+ val inputDim = dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size
+ val model = createRawLSHModel(inputDim).setParent(this)
+ copyValues(model)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
new file mode 100644
index 0000000000..d9d0f32254
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ *
+ * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is
+ * a perfect hash function:
+ * `h_i(x) = (x * k_i mod prime) mod numEntries`
+ * where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*`
+ *
+ * Reference:
+ * [[https://en.wikipedia.org/wiki/Perfect_hash_function Wikipedia on Perfect Hash Function]]
+ *
+ * @param numEntries The number of entries of the hash functions.
+ * @param randCoefficients An array of random coefficients, each used by one hash function.
+ */
+@Experimental
+@Since("2.1.0")
+class MinHashModel private[ml] (
+ override val uid: String,
+ @Since("2.1.0") val numEntries: Int,
+ @Since("2.1.0") val randCoefficients: Array[Int])
+ extends LSHModel[MinHashModel] {
+
+ @Since("2.1.0")
+ override protected[ml] val hashFunction: Vector => Vector = {
+ elems: Vector =>
+ require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
+ val elemsList = elems.toSparse.indices.toList
+ val hashValues = randCoefficients.map({ randCoefficient: Int =>
+ elemsList.map({elem: Int =>
+ (1 + elem) * randCoefficient.toLong % MinHash.prime % numEntries
+ }).min.toDouble
+ })
+ Vectors.dense(hashValues)
+ }
+
+ @Since("2.1.0")
+ override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
+ val xSet = x.toSparse.indices.toSet
+ val ySet = y.toSparse.indices.toSet
+ val intersectionSize = xSet.intersect(ySet).size.toDouble
+ val unionSize = xSet.size + ySet.size - intersectionSize
+ assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
+ 1 - intersectionSize / unionSize
+ }
+
+ @Since("2.1.0")
+ override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
+ // Since it's generated by hashing, it will be a pair of dense vectors.
+ x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
+ }
+
+ @Since("2.1.0")
+ override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+
+ @Since("2.1.0")
+ override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
+}
+
+/**
+ * :: Experimental ::
+ *
+ * LSH class for Jaccard distance.
+ *
+ * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
+ * `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
+ * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
+ * Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
+ * as binary "1" values.
+ *
+ * References:
+ * [[https://en.wikipedia.org/wiki/MinHash Wikipedia on MinHash]]
+ */
+@Experimental
+@Since("2.1.0")
+class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
+
+
+ @Since("2.1.0")
+ override def setInputCol(value: String): this.type = super.setInputCol(value)
+
+ @Since("2.1.0")
+ override def setOutputCol(value: String): this.type = super.setOutputCol(value)
+
+ @Since("2.1.0")
+ override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
+
+ @Since("2.1.0")
+ def this() = {
+ this(Identifiable.randomUID("min hash"))
+ }
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("2.1.0")
+ override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
+ require(inputDim <= MinHash.prime / 2,
+ s"The input vector dimension $inputDim exceeds the threshold ${MinHash.prime / 2}.")
+ val rand = new Random($(seed))
+ val numEntry = inputDim * 2
+ val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1))
+ new MinHashModel(uid, numEntry, randCoofs)
+ }
+
+ @Since("2.1.0")
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ validateAndTransformSchema(schema)
+ }
+
+ @Since("2.1.0")
+ override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+}
+
+@Since("2.1.0")
+object MinHash extends DefaultParamsReadable[MinHash] {
+ // A large prime smaller than sqrt(2^63 − 1)
+ private[ml] val prime = 2038074743
+
+ @Since("2.1.0")
+ override def load(path: String): MinHash = super.load(path)
+}
+
+@Since("2.1.0")
+object MinHashModel extends MLReadable[MinHashModel] {
+
+ @Since("2.1.0")
+ override def read: MLReader[MinHashModel] = new MinHashModelReader
+
+ @Since("2.1.0")
+ override def load(path: String): MinHashModel = super.load(path)
+
+ private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
+
+ private case class Data(numEntries: Int, randCoefficients: Array[Int])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.numEntries, instance.randCoefficients)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class MinHashModelReader extends MLReader[MinHashModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[MinHashModel].getName
+
+ override def load(path: String): MinHashModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
+ val numEntries = data.getAs[Int](0)
+ val randCoefficients = data.getAs[Seq[Int]](1).toArray
+ val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala
new file mode 100644
index 0000000000..1b524c6710
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.Random
+
+import breeze.linalg.normalize
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ *
+ * Params for [[RandomProjection]].
+ */
+private[ml] trait RandomProjectionParams extends Params {
+
+ /**
+ * The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
+ * buckets will be `(max L2 norm of input vectors) / bucketLength`.
+ *
+ *
+ * If input vectors are normalized, 1-10 times of pow(numRecords, -1/inputDim) would be a
+ * reasonable value
+ * @group param
+ */
+ val bucketLength: DoubleParam = new DoubleParam(this, "bucketLength",
+ "the length of each hash bucket, a larger bucket lowers the false negative rate.",
+ ParamValidators.gt(0))
+
+ /** @group getParam */
+ final def getBucketLength: Double = $(bucketLength)
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Model produced by [[RandomProjection]], where multiple random vectors are stored. The vectors
+ * are normalized to be unit vectors and each vector is used in a hash function:
+ * `h_i(x) = floor(r_i.dot(x) / bucketLength)`
+ * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
+ * vectors) / bucketLength`.
+ *
+ * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function.
+ */
+@Experimental
+@Since("2.1.0")
+class RandomProjectionModel private[ml] (
+ override val uid: String,
+ @Since("2.1.0") val randUnitVectors: Array[Vector])
+ extends LSHModel[RandomProjectionModel] with RandomProjectionParams {
+
+ @Since("2.1.0")
+ override protected[ml] val hashFunction: (Vector) => Vector = {
+ key: Vector => {
+ val hashValues: Array[Double] = randUnitVectors.map({
+ randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
+ })
+ Vectors.dense(hashValues)
+ }
+ }
+
+ @Since("2.1.0")
+ override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
+ Math.sqrt(Vectors.sqdist(x, y))
+ }
+
+ @Since("2.1.0")
+ override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
+ // Since it's generated by hashing, it will be a pair of dense vectors.
+ x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
+ }
+
+ @Since("2.1.0")
+ override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+
+ @Since("2.1.0")
+ override def write: MLWriter = new RandomProjectionModel.RandomProjectionModelWriter(this)
+}
+
+/**
+ * :: Experimental ::
+ *
+ * This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean
+ * distance metrics.
+ *
+ * The input is dense or sparse vectors, each of which represents a point in the Euclidean
+ * distance space. The output will be vectors of configurable dimension. Hash value in the same
+ * dimension is calculated by the same hash function.
+ *
+ * References:
+ *
+ * 1. [[https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions
+ * Wikipedia on Stable Distributions]]
+ *
+ * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
+ * arXiv:1408.2927 (2014).
+ */
+@Experimental
+@Since("2.1.0")
+class RandomProjection(override val uid: String) extends LSH[RandomProjectionModel]
+ with RandomProjectionParams with HasSeed {
+
+ @Since("2.1.0")
+ override def setInputCol(value: String): this.type = super.setInputCol(value)
+
+ @Since("2.1.0")
+ override def setOutputCol(value: String): this.type = super.setOutputCol(value)
+
+ @Since("2.1.0")
+ override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
+
+ @Since("2.1.0")
+ def this() = {
+ this(Identifiable.randomUID("random projection"))
+ }
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setBucketLength(value: Double): this.type = set(bucketLength, value)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("2.1.0")
+ override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = {
+ val rand = new Random($(seed))
+ val randUnitVectors: Array[Vector] = {
+ Array.fill($(outputDim)) {
+ val randArray = Array.fill(inputDim)(rand.nextGaussian())
+ Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
+ }
+ }
+ new RandomProjectionModel(uid, randUnitVectors)
+ }
+
+ @Since("2.1.0")
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ validateAndTransformSchema(schema)
+ }
+
+ @Since("2.1.0")
+ override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+}
+
+@Since("2.1.0")
+object RandomProjection extends DefaultParamsReadable[RandomProjection] {
+
+ @Since("2.1.0")
+ override def load(path: String): RandomProjection = super.load(path)
+}
+
+@Since("2.1.0")
+object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
+
+ @Since("2.1.0")
+ override def read: MLReader[RandomProjectionModel] = new RandomProjectionModelReader
+
+ @Since("2.1.0")
+ override def load(path: String): RandomProjectionModel = super.load(path)
+
+ private[RandomProjectionModel] class RandomProjectionModelWriter(instance: RandomProjectionModel)
+ extends MLWriter {
+
+ // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
+ private case class Data(randUnitVectors: Matrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val numRows = instance.randUnitVectors.length
+ require(numRows > 0)
+ val numCols = instance.randUnitVectors.head.size
+ val values = instance.randUnitVectors.map(_.toArray).reduce(Array.concat(_, _))
+ val randMatrix = Matrices.dense(numRows, numCols, values)
+ val data = Data(randMatrix)
+ val dataPath = new Path(path, "data").toString
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class RandomProjectionModelReader extends MLReader[RandomProjectionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomProjectionModel].getName
+
+ override def load(path: String): RandomProjectionModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sparkSession.read.parquet(dataPath)
+ val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
+ .select("randUnitVectors")
+ .head()
+ val model = new RandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
new file mode 100644
index 0000000000..5c025546f3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DataTypes
+
+private[ml] object LSHTest {
+ /**
+ * For any locality sensitive function h in a metric space, we meed to verify whether
+ * the following property is satisfied.
+ *
+ * There exist dist1, dist2, p1, p2, so that for any two elements e1 and e2,
+ * If dist(e1, e2) <= dist1, then Pr{h(x) == h(y)} >= p1
+ * If dist(e1, e2) >= dist2, then Pr{h(x) == h(y)} <= p2
+ *
+ * This is called locality sensitive property. This method checks the property on an
+ * existing dataset and calculate the probabilities.
+ * (https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Definition)
+ *
+ * This method hashes each elements to hash buckets using LSH, and calculate the false positive
+ * and false negative:
+ * False positive: Of all (e1, e2) sharing any bucket, the probability of dist(e1, e2) > distFP
+ * False negative: Of all (e1, e2) not sharing buckets, the probability of dist(e1, e2) < distFN
+ *
+ * @param dataset The dataset to verify the locality sensitive hashing property.
+ * @param lsh The lsh instance to perform the hashing
+ * @param distFP Distance threshold for false positive
+ * @param distFN Distance threshold for false negative
+ * @tparam T The class type of lsh
+ * @return A tuple of two doubles, representing the false positive and false negative rate
+ */
+ def calculateLSHProperty[T <: LSHModel[T]](
+ dataset: Dataset[_],
+ lsh: LSH[T],
+ distFP: Double,
+ distFN: Double): (Double, Double) = {
+ val model = lsh.fit(dataset)
+ val inputCol = model.getInputCol
+ val outputCol = model.getOutputCol
+ val transformedData = model.transform(dataset)
+
+ SchemaUtils.checkColumnType(transformedData.schema, model.getOutputCol, new VectorUDT)
+
+ // Perform a cross join and label each pair of same_bucket and distance
+ val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
+ val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
+ val sameBucket = udf((x: Vector, y: Vector) => model.hashDistance(x, y) == 0.0,
+ DataTypes.BooleanType)
+ val result = pairs
+ .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
+ .withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))
+
+ // Compute the probabilities based on the join result
+ val positive = result.filter(col("same_bucket"))
+ val negative = result.filter(!col("same_bucket"))
+ val falsePositiveCount = positive.filter(col("distance") > distFP).count().toDouble
+ val falseNegativeCount = negative.filter(col("distance") < distFN).count().toDouble
+ (falsePositiveCount / positive.count(), falseNegativeCount / negative.count())
+ }
+
+ /**
+ * Compute the precision and recall of approximate nearest neighbors
+ * @param lsh The lsh instance
+ * @param dataset the dataset to look for the key
+ * @param key The key to hash for the item
+ * @param k The maximum number of items closest to the key
+ * @tparam T The class type of lsh
+ * @return A tuple of two doubles, representing precision and recall rate
+ */
+ def calculateApproxNearestNeighbors[T <: LSHModel[T]](
+ lsh: LSH[T],
+ dataset: Dataset[_],
+ key: Vector,
+ k: Int,
+ singleProbing: Boolean): (Double, Double) = {
+ val model = lsh.fit(dataset)
+
+ // Compute expected
+ val distUDF = udf((x: Vector) => model.keyDistance(x, key), DataTypes.DoubleType)
+ val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)
+
+ // Compute actual
+ val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol")
+
+ assert(actual.schema.sameType(model
+ .transformSchema(dataset.schema)
+ .add("distCol", DataTypes.DoubleType))
+ )
+
+ if (!singleProbing) {
+ assert(actual.count() == k)
+ }
+
+ // Compute precision and recall
+ val correctCount = expected.join(actual, model.getInputCol).count().toDouble
+ (correctCount / actual.count(), correctCount / expected.count())
+ }
+
+ /**
+ * Compute the precision and recall of approximate similarity join
+ * @param lsh The lsh instance
+ * @param datasetA One of the datasets to join
+ * @param datasetB Another dataset to join
+ * @param threshold The threshold for the distance of record pairs
+ * @tparam T The class type of lsh
+ * @return A tuple of two doubles, representing precision and recall rate
+ */
+ def calculateApproxSimilarityJoin[T <: LSHModel[T]](
+ lsh: LSH[T],
+ datasetA: Dataset[_],
+ datasetB: Dataset[_],
+ threshold: Double): (Double, Double) = {
+ val model = lsh.fit(datasetA)
+ val inputCol = model.getInputCol
+
+ // Compute expected
+ val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
+ val expected = datasetA.as("a").crossJoin(datasetB.as("b"))
+ .filter(distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")) < threshold)
+
+ // Compute actual
+ val actual = model.approxSimilarityJoin(datasetA, datasetB, threshold)
+
+ SchemaUtils.checkColumnType(actual.schema, "distCol", DataTypes.DoubleType)
+ assert(actual.schema.apply("datasetA").dataType
+ .sameType(model.transformSchema(datasetA.schema)))
+ assert(actual.schema.apply("datasetB").dataType
+ .sameType(model.transformSchema(datasetB.schema)))
+
+ // Compute precision and recall
+ val correctCount = actual.filter(col("distCol") < threshold).count().toDouble
+ (correctCount / actual.count(), correctCount / expected.count())
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
new file mode 100644
index 0000000000..c32ca7d69c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ @transient var dataset: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val data = {
+ for (i <- 0 to 95) yield Vectors.sparse(100, (i until i + 5).map((_, 1.0)))
+ }
+ dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new MinHash)
+ val model = new MinHashModel("mh", numEntries = 2, randCoefficients = Array(1))
+ ParamsSuite.checkParams(model)
+ }
+
+ test("MinHash: default params") {
+ val rp = new MinHash
+ assert(rp.getOutputDim === 1.0)
+ }
+
+ test("read/write") {
+ def checkModelData(model: MinHashModel, model2: MinHashModel): Unit = {
+ assert(model.numEntries === model2.numEntries)
+ assertResult(model.randCoefficients)(model2.randCoefficients)
+ }
+ val mh = new MinHash()
+ val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
+ testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ }
+
+ test("hashFunction") {
+ val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(0, 1, 3))
+ val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
+ assert(res.equals(Vectors.dense(0.0, 3.0, 4.0)))
+ }
+
+ test("keyDistance and hashDistance") {
+ val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(1))
+ val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))
+ val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0)))
+ val keyDist = model.keyDistance(v1, v2)
+ val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
+ assert(keyDist === 0.5)
+ assert(hashDist === 3)
+ }
+
+ test("MinHash: test of LSH property") {
+ val mh = new MinHash()
+ .setOutputDim(1)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setSeed(12344)
+
+ val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, mh, 0.75, 0.5)
+ assert(falsePositive < 0.3)
+ assert(falseNegative < 0.3)
+ }
+
+ test("approxNearestNeighbors for min hash") {
+ val mh = new MinHash()
+ .setOutputDim(20)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setSeed(12345)
+
+ val key: Vector = Vectors.sparse(100,
+ (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))
+
+ val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
+ singleProbing = true)
+ assert(precision >= 0.7)
+ assert(recall >= 0.7)
+ }
+
+ test("approxSimilarityJoin for minhash on different dataset") {
+ val data1 = {
+ for (i <- 0 until 20) yield Vectors.sparse(100, (5 * i until 5 * i + 5).map((_, 1.0)))
+ }
+ val df1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys")
+
+ val data2 = {
+ for (i <- 0 until 30) yield Vectors.sparse(100, (3 * i until 3 * i + 3).map((_, 1.0)))
+ }
+ val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
+
+ val mh = new MinHash()
+ .setOutputDim(20)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setSeed(12345)
+
+ val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(mh, df1, df2, 0.5)
+ assert(precision == 1.0)
+ assert(recall >= 0.7)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
new file mode 100644
index 0000000000..cd82ee2117
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import breeze.numerics.{cos, sin}
+import breeze.numerics.constants.Pi
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+class RandomProjectionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ @transient var dataset: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val data = {
+ for (i <- -10 until 10; j <- -10 until 10) yield Vectors.dense(i.toDouble, j.toDouble)
+ }
+ dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new RandomProjection)
+ val model = new RandomProjectionModel("rp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
+ ParamsSuite.checkParams(model)
+ }
+
+ test("RandomProjection: default params") {
+ val rp = new RandomProjection
+ assert(rp.getOutputDim === 1.0)
+ }
+
+ test("read/write") {
+ def checkModelData(model: RandomProjectionModel, model2: RandomProjectionModel): Unit = {
+ model.randUnitVectors.zip(model2.randUnitVectors)
+ .foreach(pair => assert(pair._1 === pair._2))
+ }
+ val mh = new RandomProjection()
+ val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ }
+
+ test("hashFunction") {
+ val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0))
+ val model = new RandomProjectionModel("rp", randUnitVectors)
+ model.set(model.bucketLength, 0.5)
+ val res = model.hashFunction(Vectors.dense(1.23, 4.56))
+ assert(res.equals(Vectors.dense(9.0, 2.0)))
+ }
+
+ test("keyDistance and hashDistance") {
+ val model = new RandomProjectionModel("rp", Array(Vectors.dense(0.0, 1.0)))
+ val keyDist = model.keyDistance(Vectors.dense(1, 2), Vectors.dense(-2, -2))
+ val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
+ assert(keyDist === 5)
+ assert(hashDist === 3)
+ }
+
+ test("RandomProjection: randUnitVectors") {
+ val rp = new RandomProjection()
+ .setOutputDim(20)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(1.0)
+ .setSeed(12345)
+ val unitVectors = rp.fit(dataset).randUnitVectors
+ unitVectors.foreach { v: Vector =>
+ assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
+ }
+ }
+
+ test("RandomProjection: test of LSH property") {
+ // Project from 2 dimensional Euclidean Space to 1 dimensions
+ val rp = new RandomProjection()
+ .setOutputDim(1)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(1.0)
+ .setSeed(12345)
+
+ val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, rp, 8.0, 2.0)
+ assert(falsePositive < 0.4)
+ assert(falseNegative < 0.4)
+ }
+
+ test("RandomProjection with high dimension data: test of LSH property") {
+ val numDim = 100
+ val data = {
+ for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2))
+ yield Vectors.sparse(numDim, Seq((i, j.toDouble)))
+ }
+ val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+
+ // Project from 100 dimensional Euclidean Space to 10 dimensions
+ val rp = new RandomProjection()
+ .setOutputDim(10)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(2.5)
+ .setSeed(12345)
+
+ val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, rp, 3.0, 2.0)
+ assert(falsePositive < 0.3)
+ assert(falseNegative < 0.3)
+ }
+
+ test("approxNearestNeighbors for random projection") {
+ val key = Vectors.dense(1.2, 3.4)
+
+ val rp = new RandomProjection()
+ .setOutputDim(2)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(4.0)
+ .setSeed(12345)
+
+ val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
+ singleProbing = true)
+ assert(precision >= 0.6)
+ assert(recall >= 0.6)
+ }
+
+ test("approxNearestNeighbors with multiple probing") {
+ val key = Vectors.dense(1.2, 3.4)
+
+ val rp = new RandomProjection()
+ .setOutputDim(20)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(1.0)
+ .setSeed(12345)
+
+ val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
+ singleProbing = false)
+ assert(precision >= 0.7)
+ assert(recall >= 0.7)
+ }
+
+ test("approxSimilarityJoin for random projection on different dataset") {
+ val data2 = {
+ for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
+ }
+ val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
+
+ val rp = new RandomProjection()
+ .setOutputDim(2)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(4.0)
+ .setSeed(12345)
+
+ val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, dataset, dataset2, 1.0)
+ assert(precision == 1.0)
+ assert(recall >= 0.7)
+ }
+
+ test("approxSimilarityJoin for self join") {
+ val data = {
+ for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
+ }
+ val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+
+ val rp = new RandomProjection()
+ .setOutputDim(2)
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setBucketLength(4.0)
+ .setSeed(12345)
+
+ val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, df, df, 3.0)
+ assert(precision == 1.0)
+ assert(recall >= 0.7)
+ }
+}