author    Yun Ni <yunn@uber.com>    2016-11-28 15:14:46 -0800
committer Joseph K. Bradley <joseph@databricks.com>    2016-11-28 15:14:46 -0800
commit    05f7c6ffab2a6be548375cd624dc27092677232f (patch)
tree      27a954222f507a44273df13222d0946a7b485eed /mllib
parent    8b1609bebe489b2ef78db4be6e9836687089fe3d (diff)
[SPARK-18408][ML] API Improvements for LSH
## What changes were proposed in this pull request?

(1) Change the output schema to `Array of Vector` instead of `Vectors`.
(2) Use `numHashTables` as the dimension of the Array.
(3) Rename `RandomProjection` to `BucketedRandomProjectionLSH` and `MinHash` to `MinHashLSH`.
(4) Make `randUnitVectors`/`randCoefficients` private.
(5) Make Multi-Probe NN Search and `hashDistance` private for future discussion.

Saved for future PRs:
(1) AND-amplification and `numHashFunctions` as the dimension of Vector are saved for a future PR.
(2) `hashDistance` and Multi-Probe NN Search need more discussion. The current implementation is just a backward-compatible one.

## How was this patch tested?

Related unit tests are modified to make sure the performance of LSH is preserved and the outputs of the APIs meet expectations.

Author: Yun Ni <yunn@uber.com>
Author: Yunni <Euler57721@gmail.com>

Closes #15874 from Yunni/SPARK-18408-yunn-api-improvements.
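For reference, a minimal usage sketch of the renamed estimator under the new API. This is illustrative only: it assumes an existing SparkSession and a DataFrame `df` with a Vector column named "keys", and uses only setters that appear in this diff (`setNumHashTables` replaces `setOutputDim`, and the output column now holds an array of vectors, one per hash table).

import org.apache.spark.ml.feature.BucketedRandomProjectionLSH

// Hedged sketch; `df` and the "keys"/"hashes" column names are assumptions, not part of this patch.
val brp = new BucketedRandomProjectionLSH()
  .setNumHashTables(3)        // was setOutputDim before this change
  .setBucketLength(2.0)
  .setInputCol("keys")
  .setOutputCol("hashes")     // output column type is now array<vector>
  .setSeed(12345)

val model = brp.fit(df)
val hashed = model.transform(df)   // each row gets 3 single-element hash vectors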
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala (renamed from mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala) | 77
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala | 138
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala (renamed from mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala) | 112
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala) | 100
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala | 17
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala) | 83
6 files changed, 306 insertions, 221 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 2bff59a0da..cbac16345a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -34,9 +34,9 @@ import org.apache.spark.sql.types.StructType
/**
* :: Experimental ::
*
- * Params for [[RandomProjection]].
+ * Params for [[BucketedRandomProjectionLSH]].
*/
-private[ml] trait RandomProjectionParams extends Params {
+private[ml] trait BucketedRandomProjectionLSHParams extends Params {
/**
* The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
@@ -58,8 +58,8 @@ private[ml] trait RandomProjectionParams extends Params {
/**
* :: Experimental ::
*
- * Model produced by [[RandomProjection]], where multiple random vectors are stored. The vectors
- * are normalized to be unit vectors and each vector is used in a hash function:
+ * Model produced by [[BucketedRandomProjectionLSH]], where multiple random vectors are stored. The
+ * vectors are normalized to be unit vectors and each vector is used in a hash function:
* `h_i(x) = floor(r_i.dot(x) / bucketLength)`
* where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
* vectors) / bucketLength`.
@@ -68,18 +68,19 @@ private[ml] trait RandomProjectionParams extends Params {
*/
@Experimental
@Since("2.1.0")
-class RandomProjectionModel private[ml] (
+class BucketedRandomProjectionLSHModel private[ml](
override val uid: String,
- @Since("2.1.0") val randUnitVectors: Array[Vector])
- extends LSHModel[RandomProjectionModel] with RandomProjectionParams {
+ private[ml] val randUnitVectors: Array[Vector])
+ extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
@Since("2.1.0")
- override protected[ml] val hashFunction: (Vector) => Vector = {
+ override protected[ml] val hashFunction: Vector => Array[Vector] = {
key: Vector => {
val hashValues: Array[Double] = randUnitVectors.map({
randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
})
- Vectors.dense(hashValues)
+ // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
+ hashValues.map(Vectors.dense(_))
}
}
@@ -89,27 +90,29 @@ class RandomProjectionModel private[ml] (
}
@Since("2.1.0")
- override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
+ override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
// Since it's generated by hashing, it will be a pair of dense vectors.
- x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
+ x.zip(y).map(vectorPair => Vectors.sqdist(vectorPair._1, vectorPair._2)).min
}
@Since("2.1.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
@Since("2.1.0")
- override def write: MLWriter = new RandomProjectionModel.RandomProjectionModelWriter(this)
+ override def write: MLWriter = {
+ new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
+ }
}
/**
* :: Experimental ::
*
- * This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean
- * distance metrics.
+ * This [[BucketedRandomProjectionLSH]] implements Locality Sensitive Hashing functions for
+ * Euclidean distance metrics.
*
* The input is dense or sparse vectors, each of which represents a point in the Euclidean
- * distance space. The output will be vectors of configurable dimension. Hash value in the same
- * dimension is calculated by the same hash function.
+ * distance space. The output will be vectors of configurable dimension. Hash values in the
+ * same dimension are calculated by the same hash function.
*
* References:
*
@@ -121,8 +124,9 @@ class RandomProjectionModel private[ml] (
*/
@Experimental
@Since("2.1.0")
-class RandomProjection(override val uid: String) extends LSH[RandomProjectionModel]
- with RandomProjectionParams with HasSeed {
+class BucketedRandomProjectionLSH(override val uid: String)
+ extends LSH[BucketedRandomProjectionLSHModel]
+ with BucketedRandomProjectionLSHParams with HasSeed {
@Since("2.1.0")
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -131,11 +135,11 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
override def setOutputCol(value: String): this.type = super.setOutputCol(value)
@Since("2.1.0")
- override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
+ override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)
@Since("2.1.0")
def this() = {
- this(Identifiable.randomUID("random projection"))
+ this(Identifiable.randomUID("brp-lsh"))
}
/** @group setParam */
@@ -147,15 +151,16 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
def setSeed(value: Long): this.type = set(seed, value)
@Since("2.1.0")
- override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = {
+ override protected[this] def createRawLSHModel(
+ inputDim: Int): BucketedRandomProjectionLSHModel = {
val rand = new Random($(seed))
val randUnitVectors: Array[Vector] = {
- Array.fill($(outputDim)) {
+ Array.fill($(numHashTables)) {
val randArray = Array.fill(inputDim)(rand.nextGaussian())
Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
}
}
- new RandomProjectionModel(uid, randUnitVectors)
+ new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
}
@Since("2.1.0")
@@ -169,23 +174,25 @@ class RandomProjection(override val uid: String) extends LSH[RandomProjectionMod
}
@Since("2.1.0")
-object RandomProjection extends DefaultParamsReadable[RandomProjection] {
+object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomProjectionLSH] {
@Since("2.1.0")
- override def load(path: String): RandomProjection = super.load(path)
+ override def load(path: String): BucketedRandomProjectionLSH = super.load(path)
}
@Since("2.1.0")
-object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
+object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {
@Since("2.1.0")
- override def read: MLReader[RandomProjectionModel] = new RandomProjectionModelReader
+ override def read: MLReader[BucketedRandomProjectionLSHModel] = {
+ new BucketedRandomProjectionLSHModelReader
+ }
@Since("2.1.0")
- override def load(path: String): RandomProjectionModel = super.load(path)
+ override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)
- private[RandomProjectionModel] class RandomProjectionModelWriter(instance: RandomProjectionModel)
- extends MLWriter {
+ private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
+ instance: BucketedRandomProjectionLSHModel) extends MLWriter {
// TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
private case class Data(randUnitVectors: Matrix)
@@ -203,12 +210,13 @@ object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
}
}
- private class RandomProjectionModelReader extends MLReader[RandomProjectionModel] {
+ private class BucketedRandomProjectionLSHModelReader
+ extends MLReader[BucketedRandomProjectionLSHModel] {
/** Checked against metadata when loading model */
- private val className = classOf[RandomProjectionModel].getName
+ private val className = classOf[BucketedRandomProjectionLSHModel].getName
- override def load(path: String): RandomProjectionModel = {
+ override def load(path: String): BucketedRandomProjectionLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
@@ -216,7 +224,8 @@ object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
.select("randUnitVectors")
.head()
- val model = new RandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
+ val model = new BucketedRandomProjectionLSHModel(metadata.uid,
+ randUnitVectors.rowIter.toArray)
DefaultParamsReader.getAndSetParams(model, metadata)
model
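As a sanity check of the hash formula documented above, `h_i(x) = floor(r_i.dot(x) / bucketLength)`, the following standalone sketch reproduces the values asserted in the "hashFunction" test of BucketedRandomProjectionLSHSuite further down (random unit vectors (0, 1) and (1, 0), bucketLength 0.5, key (1.23, 4.56)); it uses a plain Scala dot product rather than the private BLAS helper.

import org.apache.spark.ml.linalg.Vectors

val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0))
val key = Vectors.dense(1.23, 4.56)
val bucketLength = 0.5

val hashes = randUnitVectors.map { r =>
  // dot(r, key) computed element-wise, then bucketed
  val dot = r.toArray.zip(key.toArray).map { case (a, b) => a * b }.sum
  math.floor(dot / bucketLength)
}
// hashes == Array(9.0, 2.0)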
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index eb117c40ee..309cc2ef52 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -33,28 +33,28 @@ import org.apache.spark.sql.types._
*/
private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
/**
- * Param for the dimension of LSH OR-amplification.
+ * Param for the number of hash tables used in LSH OR-amplification.
*
- * In this implementation, we use LSH OR-amplification to reduce the false negative rate. The
- * higher the dimension is, the lower the false negative rate.
+ * LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
+ * param lead to a reduced false negative rate, at the expense of added computational complexity.
* @group param
*/
- final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" +
- " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" +
- " improves the running performance", ParamValidators.gt(0))
+ final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
+ "tables, where increasing number of hash tables lowers the false negative rate, and " +
+ "decreasing it improves the running performance", ParamValidators.gt(0))
/** @group getParam */
- final def getOutputDim: Int = $(outputDim)
+ final def getNumHashTables: Int = $(numHashTables)
- setDefault(outputDim -> 1)
+ setDefault(numHashTables -> 1)
/**
* Transform the Schema for LSH
- * @param schema The schema of the input dataset without [[outputCol]]
- * @return A derived schema with [[outputCol]] added
+ * @param schema The schema of the input dataset without [[outputCol]].
+ * @return A derived schema with [[outputCol]] added.
*/
protected[this] final def validateAndTransformSchema(schema: StructType): StructType = {
- SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(outputCol), DataTypes.createArrayType(new VectorUDT))
}
}
@@ -66,32 +66,32 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
self: T =>
/**
- * The hash function of LSH, mapping a predefined KeyType to a Vector
+ * The hash function of LSH, mapping an input feature vector to multiple hash vectors.
* @return The mapping of LSH function.
*/
- protected[ml] val hashFunction: Vector => Vector
+ protected[ml] val hashFunction: Vector => Array[Vector]
/**
* Calculate the distance between two different keys using the distance metric corresponding
- * to the hashFunction
- * @param x One input vector in the metric space
- * @param y One input vector in the metric space
- * @return The distance between x and y
+ * to the hashFunction.
+ * @param x One input vector in the metric space.
+ * @param y One input vector in the metric space.
+ * @return The distance between x and y.
*/
protected[ml] def keyDistance(x: Vector, y: Vector): Double
/**
* Calculate the distance between two different hash Vectors.
*
- * @param x One of the hash vector
- * @param y Another hash vector
- * @return The distance between hash vectors x and y
+ * @param x One of the hash vector.
+ * @param y Another hash vector.
+ * @return The distance between hash vectors x and y.
*/
- protected[ml] def hashDistance(x: Vector, y: Vector): Double
+ protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val transformUDF = udf(hashFunction, new VectorUDT)
+ val transformUDF = udf(hashFunction, DataTypes.createArrayType(new VectorUDT))
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
}
@@ -99,29 +99,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
validateAndTransformSchema(schema)
}
- /**
- * Given a large dataset and an item, approximately find at most k items which have the closest
- * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
- * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
- * transformed data when necessary.
- *
- * This method implements two ways of fetching k nearest neighbors:
- * - Single Probing: Fast, return at most k elements (Probing only one buckets)
- * - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
- *
- * @param dataset the dataset to search for nearest neighbors of the key
- * @param key Feature vector representing the item to search for
- * @param numNearestNeighbors The maximum number of nearest neighbors
- * @param singleProbing True for using Single Probing; false for multiple probing
- * @param distCol Output column for storing the distance between each result row and the key
- * @return A dataset containing at most k items closest to the key. A distCol is added to show
- * the distance between each row and the key.
- */
- def approxNearestNeighbors(
+ // TODO: Fix the MultiProbe NN Search in SPARK-18454
+ private[feature] def approxNearestNeighbors(
dataset: Dataset[_],
key: Vector,
numNearestNeighbors: Int,
- singleProbing: Boolean,
+ singleProbe: Boolean,
distCol: String): Dataset[_] = {
require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
// Get Hash Value of the key
@@ -132,14 +115,24 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
dataset.toDF()
}
- // In the origin dataset, find the hash value that is closest to the key
- val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType)
- val hashDistCol = hashDistUDF(col($(outputCol)))
+ val modelSubset = if (singleProbe) {
+ def sameBucket(x: Seq[Vector], y: Seq[Vector]): Boolean = {
+ x.zip(y).exists(tuple => tuple._1 == tuple._2)
+ }
+
+ // In the origin dataset, find the hash value that hash the same bucket with the key
+ val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
+ sameBucket(x, keyHash), DataTypes.BooleanType)
- val modelSubset = if (singleProbing) {
- modelDataset.filter(hashDistCol === 0.0)
+ modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
} else {
+ // In the origin dataset, find the hash value that is closest to the key
+ // Limit the use of hashDist since it's controversial
+ val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
+ val hashDistCol = hashDistUDF(col($(outputCol)))
+
// Compute threshold to get exact k elements.
+ // TODO: SPARK-18409: Use approxQuantile to get the threshold
val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
@@ -155,8 +148,30 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
}
/**
- * Overloaded method for approxNearestNeighbors. Use Single Probing as default way to search
- * nearest neighbors and "distCol" as default distCol.
+ * Given a large dataset and an item, approximately find at most k items which have the closest
+ * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
+ * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
+ * transformed data when necessary.
+ *
+ * @note This method is experimental and will likely change behavior in the next release.
+ *
+ * @param dataset The dataset to search for nearest neighbors of the key.
+ * @param key Feature vector representing the item to search for.
+ * @param numNearestNeighbors The maximum number of nearest neighbors.
+ * @param distCol Output column for storing the distance between each result row and the key.
+ * @return A dataset containing at most k items closest to the key. A column "distCol" is added
+ * to show the distance between each row and the key.
+ */
+ def approxNearestNeighbors(
+ dataset: Dataset[_],
+ key: Vector,
+ numNearestNeighbors: Int,
+ distCol: String): Dataset[_] = {
+ approxNearestNeighbors(dataset, key, numNearestNeighbors, true, distCol)
+ }
+
+ /**
+ * Overloaded method for approxNearestNeighbors. Use "distCol" as default distCol.
*/
def approxNearestNeighbors(
dataset: Dataset[_],
@@ -172,31 +187,28 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
*
* @param dataset The dataset to transform and explode.
* @param explodeCols The alias for the exploded columns, must be a seq of two strings.
- * @return A dataset containing idCol, inputCol and explodeCols
+ * @return A dataset containing idCol, inputCol and explodeCols.
*/
private[this] def processDataset(
dataset: Dataset[_],
inputName: String,
explodeCols: Seq[String]): Dataset[_] = {
require(explodeCols.size == 2, "explodeCols must be two strings.")
- val vectorToMap = udf((x: Vector) => x.asBreeze.iterator.toMap,
- MapType(DataTypes.IntegerType, DataTypes.DoubleType))
val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
transform(dataset)
} else {
dataset.toDF()
}
modelDataset.select(
- struct(col("*")).as(inputName),
- explode(vectorToMap(col($(outputCol)))).as(explodeCols))
+ struct(col("*")).as(inputName), posexplode(col($(outputCol))).as(explodeCols))
}
/**
* Recreate a column using the same column name but different attribute id. Used in approximate
* similarity join.
- * @param dataset The dataset where a column need to recreate
- * @param colName The name of the column to recreate
- * @param tmpColName A temporary column name which does not conflict with existing columns
+ * @param dataset The dataset where a column need to recreate.
+ * @param colName The name of the column to recreate.
+ * @param tmpColName A temporary column name which does not conflict with existing columns.
* @return
*/
private[this] def recreateCol(
@@ -215,12 +227,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
* [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
* data when necessary.
*
- * @param datasetA One of the datasets to join
- * @param datasetB Another dataset to join
- * @param threshold The threshold for the distance of row pairs
- * @param distCol Output column for storing the distance between each result row and the key
+ * @param datasetA One of the datasets to join.
+ * @param datasetB Another dataset to join.
+ * @param threshold The threshold for the distance of row pairs.
+ * @param distCol Output column for storing the distance between each result row and the key.
* @return A joined dataset containing pairs of rows. The original rows are in columns
- * "datasetA" and "datasetB", and a distCol is added to show the distance of each pair
+ * "datasetA" and "datasetB", and a distCol is added to show the distance of each pair.
*/
def approxSimilarityJoin(
datasetA: Dataset[_],
@@ -293,7 +305,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
def setOutputCol(value: String): this.type = set(outputCol, value)
/** @group setParam */
- def setOutputDim(value: Int): this.type = set(outputDim, value)
+ def setNumHashTables(value: Int): this.type = set(numHashTables, value)
/**
* Validate and create a new instance of concrete LSHModel. Because different LSHModel may have
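To make the public surface left by this change concrete, here is a hedged calling sketch of the two search methods on a fitted model (single-probe only, since the multi-probe variant is now private[feature]). `model`, `dfA`, and `dfB` are assumptions: a fitted BucketedRandomProjectionLSHModel and DataFrames whose feature column matches the model's inputCol.

import org.apache.spark.ml.linalg.Vectors

val key = Vectors.dense(1.2, 3.4)

// At most 5 rows closest to `key`; a "distCol" column holding the key distance is appended.
val neighbors = model.approxNearestNeighbors(dfA, key, 5)

// All row pairs within key distance 1.5; original rows appear in columns
// "datasetA" and "datasetB", plus the named distance column.
val joined = model.approxSimilarityJoin(dfA, dfB, 1.5, "EuclideanDistance")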
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index f37233e1ab..620e1fbb09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -31,37 +31,39 @@ import org.apache.spark.sql.types.StructType
/**
* :: Experimental ::
*
- * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is
- * a perfect hash function:
- * `h_i(x) = (x * k_i mod prime) mod numEntries`
- * where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*`
+ * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function
+ * is picked from the following family of hash functions, where a_i and b_i are randomly chosen
+ * integers less than prime:
+ * `h_i(x) = ((x \cdot a_i + b_i) \mod prime)`
+ *
+ * This hash family is approximately min-wise independent according to the reference.
*
* Reference:
- * <a href="https://en.wikipedia.org/wiki/Perfect_hash_function">
- * Wikipedia on Perfect Hash Function</a>
+ * Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations."
+ * Electronic Journal of Combinatorics 7 (2000): R26.
*
- * @param numEntries The number of entries of the hash functions.
- * @param randCoefficients An array of random coefficients, each used by one hash function.
+ * @param randCoefficients Pairs of random coefficients. Each pair is used by one hash function.
*/
@Experimental
@Since("2.1.0")
-class MinHashModel private[ml] (
+class MinHashLSHModel private[ml](
override val uid: String,
- @Since("2.1.0") val numEntries: Int,
- @Since("2.1.0") val randCoefficients: Array[Int])
- extends LSHModel[MinHashModel] {
+ private[ml] val randCoefficients: Array[(Int, Int)])
+ extends LSHModel[MinHashLSHModel] {
@Since("2.1.0")
- override protected[ml] val hashFunction: Vector => Vector = {
- elems: Vector =>
+ override protected[ml] val hashFunction: Vector => Array[Vector] = {
+ elems: Vector => {
require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
val elemsList = elems.toSparse.indices.toList
- val hashValues = randCoefficients.map({ randCoefficient: Int =>
- elemsList.map({elem: Int =>
- (1 + elem) * randCoefficient.toLong % MinHash.prime % numEntries
- }).min.toDouble
- })
- Vectors.dense(hashValues)
+ val hashValues = randCoefficients.map { case (a, b) =>
+ elemsList.map { elem: Int =>
+ ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME
+ }.min.toDouble
+ }
+ // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
+ hashValues.map(Vectors.dense(_))
+ }
}
@Since("2.1.0")
@@ -75,16 +77,19 @@ class MinHashModel private[ml] (
}
@Since("2.1.0")
- override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
+ override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
// Since it's generated by hashing, it will be a pair of dense vectors.
- x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
+ // TODO: This hashDistance function requires more discussion in SPARK-18454
+ x.zip(y).map(vectorPair =>
+ vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
+ ).min
}
@Since("2.1.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
@Since("2.1.0")
- override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
+ override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
}
/**
@@ -93,18 +98,17 @@ class MinHashModel private[ml] (
* LSH class for Jaccard distance.
*
* The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
- * `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
- * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
- * Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
- * as binary "1" values.
+ * `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
+ * means there are 10 elements in the space. This set contains elements 2, 3, and 5. Also, any
+ * input vector must have at least 1 non-zero index, and all non-zero values are
+ * treated as binary "1" values.
*
* References:
* <a href="https://en.wikipedia.org/wiki/MinHash">Wikipedia on MinHash</a>
*/
@Experimental
@Since("2.1.0")
-class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
-
+class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with HasSeed {
@Since("2.1.0")
override def setInputCol(value: String): this.type = super.setInputCol(value)
@@ -113,11 +117,11 @@ class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
override def setOutputCol(value: String): this.type = super.setOutputCol(value)
@Since("2.1.0")
- override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
+ override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)
@Since("2.1.0")
def this() = {
- this(Identifiable.randomUID("min hash"))
+ this(Identifiable.randomUID("mh-lsh"))
}
/** @group setParam */
@@ -125,13 +129,14 @@ class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
def setSeed(value: Long): this.type = set(seed, value)
@Since("2.1.0")
- override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
- require(inputDim <= MinHash.prime / 2,
- s"The input vector dimension $inputDim exceeds the threshold ${MinHash.prime / 2}.")
+ override protected[ml] def createRawLSHModel(inputDim: Int): MinHashLSHModel = {
+ require(inputDim <= MinHashLSH.HASH_PRIME,
+ s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
val rand = new Random($(seed))
- val numEntry = inputDim * 2
- val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1))
- new MinHashModel(uid, numEntry, randCoofs)
+ val randCoefs: Array[(Int, Int)] = Array.fill($(numHashTables)) {
+ (1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
+ }
+ new MinHashLSHModel(uid, randCoefs)
}
@Since("2.1.0")
@@ -145,48 +150,49 @@ class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
}
@Since("2.1.0")
-object MinHash extends DefaultParamsReadable[MinHash] {
+object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
// A large prime smaller than sqrt(2^63 − 1)
- private[ml] val prime = 2038074743
+ private[ml] val HASH_PRIME = 2038074743
@Since("2.1.0")
- override def load(path: String): MinHash = super.load(path)
+ override def load(path: String): MinHashLSH = super.load(path)
}
@Since("2.1.0")
-object MinHashModel extends MLReadable[MinHashModel] {
+object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
@Since("2.1.0")
- override def read: MLReader[MinHashModel] = new MinHashModelReader
+ override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader
@Since("2.1.0")
- override def load(path: String): MinHashModel = super.load(path)
+ override def load(path: String): MinHashLSHModel = super.load(path)
- private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
+ private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel)
+ extends MLWriter {
- private case class Data(numEntries: Int, randCoefficients: Array[Int])
+ private case class Data(randCoefficients: Array[Int])
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.numEntries, instance.randCoefficients)
+ val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
- private class MinHashModelReader extends MLReader[MinHashModel] {
+ private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {
/** Checked against metadata when loading model */
- private val className = classOf[MinHashModel].getName
+ private val className = classOf[MinHashLSHModel].getName
- override def load(path: String): MinHashModel = {
+ override def load(path: String): MinHashLSHModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
- val numEntries = data.getAs[Int](0)
- val randCoefficients = data.getAs[Seq[Int]](1).toArray
- val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
+ val data = sparkSession.read.parquet(dataPath).select("randCoefficients").head()
+ val randCoefficients = data.getAs[Seq[Int]](0).grouped(2)
+ .map(tuple => (tuple(0), tuple(1))).toArray
+ val model = new MinHashLSHModel(metadata.uid, randCoefficients)
DefaultParamsReader.getAndSetParams(model, metadata)
model
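As a cross-check of the new hash family `h_i(x) = ((1 + elem) * a_i + b_i) mod HASH_PRIME` implemented in hashFunction above, this small sketch reproduces the values asserted in the "hashFunction" test of MinHashLSHSuite below (coefficient pairs (0, 1), (1, 2), (3, 0) and a sparse input whose active indices are 2, 3, 5, 7).

val prime = 2038074743                               // MinHashLSH.HASH_PRIME
val randCoefficients = Array((0, 1), (1, 2), (3, 0))
val activeIndices = List(2, 3, 5, 7)                 // non-zero indices of the sparse vector

val hashValues = randCoefficients.map { case (a, b) =>
  activeIndices.map(elem => ((1 + elem) * a + b) % prime).min.toDouble
}
// hashValues == Array(1.0, 5.0, 9.0), one value per hash table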
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index cd82ee2117..ab937685a5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
-class RandomProjectionSuite
+class BucketedRandomProjectionLSHSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: Dataset[_] = _
@@ -43,70 +43,72 @@ class RandomProjectionSuite
}
test("params") {
- ParamsSuite.checkParams(new RandomProjection)
- val model = new RandomProjectionModel("rp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
+ ParamsSuite.checkParams(new BucketedRandomProjectionLSH)
+ val model = new BucketedRandomProjectionLSHModel(
+ "brp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
ParamsSuite.checkParams(model)
}
- test("RandomProjection: default params") {
- val rp = new RandomProjection
- assert(rp.getOutputDim === 1.0)
+ test("BucketedRandomProjectionLSH: default params") {
+ val brp = new BucketedRandomProjectionLSH
+ assert(brp.getNumHashTables === 1.0)
}
test("read/write") {
- def checkModelData(model: RandomProjectionModel, model2: RandomProjectionModel): Unit = {
+ def checkModelData(
+ model: BucketedRandomProjectionLSHModel,
+ model2: BucketedRandomProjectionLSHModel): Unit = {
model.randUnitVectors.zip(model2.randUnitVectors)
.foreach(pair => assert(pair._1 === pair._2))
}
- val mh = new RandomProjection()
+ val mh = new BucketedRandomProjectionLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
}
test("hashFunction") {
val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0))
- val model = new RandomProjectionModel("rp", randUnitVectors)
+ val model = new BucketedRandomProjectionLSHModel("brp", randUnitVectors)
model.set(model.bucketLength, 0.5)
val res = model.hashFunction(Vectors.dense(1.23, 4.56))
- assert(res.equals(Vectors.dense(9.0, 2.0)))
+ assert(res.length == 2)
+ assert(res(0).equals(Vectors.dense(9.0)))
+ assert(res(1).equals(Vectors.dense(2.0)))
}
- test("keyDistance and hashDistance") {
- val model = new RandomProjectionModel("rp", Array(Vectors.dense(0.0, 1.0)))
+ test("keyDistance") {
+ val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0)))
val keyDist = model.keyDistance(Vectors.dense(1, 2), Vectors.dense(-2, -2))
- val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
assert(keyDist === 5)
- assert(hashDist === 3)
}
- test("RandomProjection: randUnitVectors") {
- val rp = new RandomProjection()
- .setOutputDim(20)
+ test("BucketedRandomProjectionLSH: randUnitVectors") {
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(20)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(1.0)
.setSeed(12345)
- val unitVectors = rp.fit(dataset).randUnitVectors
+ val unitVectors = brp.fit(dataset).randUnitVectors
unitVectors.foreach { v: Vector =>
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
}
}
- test("RandomProjection: test of LSH property") {
+ test("BucketedRandomProjectionLSH: test of LSH property") {
// Project from 2 dimensional Euclidean Space to 1 dimensions
- val rp = new RandomProjection()
- .setOutputDim(1)
+ val brp = new BucketedRandomProjectionLSH()
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(1.0)
.setSeed(12345)
- val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, rp, 8.0, 2.0)
+ val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, brp, 8.0, 2.0)
assert(falsePositive < 0.4)
assert(falseNegative < 0.4)
}
- test("RandomProjection with high dimension data: test of LSH property") {
+ test("BucketedRandomProjectionLSH with high dimension data: test of LSH property") {
val numDim = 100
val data = {
for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2))
@@ -115,30 +117,30 @@ class RandomProjectionSuite
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
// Project from 100 dimensional Euclidean Space to 10 dimensions
- val rp = new RandomProjection()
- .setOutputDim(10)
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(10)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(2.5)
.setSeed(12345)
- val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, rp, 3.0, 2.0)
+ val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, brp, 3.0, 2.0)
assert(falsePositive < 0.3)
assert(falseNegative < 0.3)
}
- test("approxNearestNeighbors for random projection") {
+ test("approxNearestNeighbors for bucketed random projection") {
val key = Vectors.dense(1.2, 3.4)
- val rp = new RandomProjection()
- .setOutputDim(2)
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(2)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(4.0)
.setSeed(12345)
- val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
- singleProbing = true)
+ val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(brp, dataset, key, 100,
+ singleProbe = true)
assert(precision >= 0.6)
assert(recall >= 0.6)
}
@@ -146,33 +148,47 @@ class RandomProjectionSuite
test("approxNearestNeighbors with multiple probing") {
val key = Vectors.dense(1.2, 3.4)
- val rp = new RandomProjection()
- .setOutputDim(20)
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(20)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(1.0)
.setSeed(12345)
- val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
- singleProbing = false)
+ val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(brp, dataset, key, 100,
+ singleProbe = false)
assert(precision >= 0.7)
assert(recall >= 0.7)
}
- test("approxSimilarityJoin for random projection on different dataset") {
+ test("approxNearestNeighbors for numNeighbors <= 0") {
+ val key = Vectors.dense(1.2, 3.4)
+
+ val model = new BucketedRandomProjectionLSHModel(
+ "brp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
+
+ intercept[IllegalArgumentException] {
+ model.approxNearestNeighbors(dataset, key, 0)
+ }
+ intercept[IllegalArgumentException] {
+ model.approxNearestNeighbors(dataset, key, -1)
+ }
+ }
+
+ test("approxSimilarityJoin for bucketed random projection on different dataset") {
val data2 = {
for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
}
val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
- val rp = new RandomProjection()
- .setOutputDim(2)
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(2)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(4.0)
.setSeed(12345)
- val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, dataset, dataset2, 1.0)
+ val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, dataset, dataset2, 1.0)
assert(precision == 1.0)
assert(recall >= 0.7)
}
@@ -183,14 +199,14 @@ class RandomProjectionSuite
}
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
- val rp = new RandomProjection()
- .setOutputDim(2)
+ val brp = new BucketedRandomProjectionLSH()
+ .setNumHashTables(2)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(4.0)
.setSeed(12345)
- val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, df, df, 3.0)
+ val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, df, df, 3.0)
assert(precision == 1.0)
assert(recall >= 0.7)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index 5c025546f3..a9b559f7ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -58,12 +58,18 @@ private[ml] object LSHTest {
val outputCol = model.getOutputCol
val transformedData = model.transform(dataset)
- SchemaUtils.checkColumnType(transformedData.schema, model.getOutputCol, new VectorUDT)
+ // Check output column type
+ SchemaUtils.checkColumnType(
+ transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT))
+
+ // Check output column dimensions
+ val headHashValue = transformedData.select(outputCol).head().get(0).asInstanceOf[Seq[Vector]]
+ assert(headHashValue.length == model.getNumHashTables)
// Perform a cross join and label each pair of same_bucket and distance
val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
- val sameBucket = udf((x: Vector, y: Vector) => model.hashDistance(x, y) == 0.0,
+ val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0,
DataTypes.BooleanType)
val result = pairs
.withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
@@ -83,6 +89,7 @@ private[ml] object LSHTest {
* @param dataset the dataset to look for the key
* @param key The key to hash for the item
* @param k The maximum number of items closest to the key
+ * @param singleProbe True for using single-probe; false for multi-probe
* @tparam T The class type of lsh
* @return A tuple of two doubles, representing precision and recall rate
*/
@@ -91,7 +98,7 @@ private[ml] object LSHTest {
dataset: Dataset[_],
key: Vector,
k: Int,
- singleProbing: Boolean): (Double, Double) = {
+ singleProbe: Boolean): (Double, Double) = {
val model = lsh.fit(dataset)
// Compute expected
@@ -99,14 +106,14 @@ private[ml] object LSHTest {
val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)
// Compute actual
- val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol")
+ val actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol")
assert(actual.schema.sameType(model
.transformSchema(dataset.schema)
.add("distCol", DataTypes.DoubleType))
)
- if (!singleProbing) {
+ if (!singleProbe) {
assert(actual.count() == k)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index c32ca7d69c..3461cdf824 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
-class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: Dataset[_] = _
@@ -38,45 +38,51 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
}
test("params") {
- ParamsSuite.checkParams(new MinHash)
- val model = new MinHashModel("mh", numEntries = 2, randCoefficients = Array(1))
+ ParamsSuite.checkParams(new MinHashLSH)
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
ParamsSuite.checkParams(model)
}
- test("MinHash: default params") {
- val rp = new MinHash
- assert(rp.getOutputDim === 1.0)
+ test("MinHashLSH: default params") {
+ val rp = new MinHashLSH
+ assert(rp.getNumHashTables === 1.0)
}
test("read/write") {
- def checkModelData(model: MinHashModel, model2: MinHashModel): Unit = {
- assert(model.numEntries === model2.numEntries)
+ def checkModelData(model: MinHashLSHModel, model2: MinHashLSHModel): Unit = {
assertResult(model.randCoefficients)(model2.randCoefficients)
}
- val mh = new MinHash()
+ val mh = new MinHashLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
}
test("hashFunction") {
- val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(0, 1, 3))
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
- assert(res.equals(Vectors.dense(0.0, 3.0, 4.0)))
+ assert(res.length == 3)
+ assert(res(0).equals(Vectors.dense(1.0)))
+ assert(res(1).equals(Vectors.dense(5.0)))
+ assert(res(2).equals(Vectors.dense(9.0)))
}
- test("keyDistance and hashDistance") {
- val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(1))
+ test("hashFunction: empty vector") {
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
+ intercept[IllegalArgumentException] {
+ model.hashFunction(Vectors.sparse(10, Seq()))
+ }
+ }
+
+ test("keyDistance") {
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))
val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0)))
val keyDist = model.keyDistance(v1, v2)
- val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
assert(keyDist === 0.5)
- assert(hashDist === 3)
}
- test("MinHash: test of LSH property") {
- val mh = new MinHash()
- .setOutputDim(1)
+ test("MinHashLSH: test of LSH property") {
+ val mh = new MinHashLSH()
.setInputCol("keys")
.setOutputCol("values")
.setSeed(12344)
@@ -86,9 +92,24 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
assert(falseNegative < 0.3)
}
+ test("MinHashLSH: test of inputDim > prime") {
+ val mh = new MinHashLSH()
+ .setInputCol("keys")
+ .setOutputCol("values")
+ .setSeed(12344)
+
+ val data = {
+ for (i <- 0 to 2) yield Vectors.sparse(Int.MaxValue, (i until i + 5).map((_, 1.0)))
+ }
+ val badDataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+ intercept[IllegalArgumentException] {
+ mh.fit(badDataset)
+ }
+ }
+
test("approxNearestNeighbors for min hash") {
- val mh = new MinHash()
- .setOutputDim(20)
+ val mh = new MinHashLSH()
+ .setNumHashTables(20)
.setInputCol("keys")
.setOutputCol("values")
.setSeed(12345)
@@ -97,12 +118,26 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
(0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))
val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
- singleProbing = true)
+ singleProbe = true)
assert(precision >= 0.7)
assert(recall >= 0.7)
}
- test("approxSimilarityJoin for minhash on different dataset") {
+ test("approxNearestNeighbors for numNeighbors <= 0") {
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+
+ val key: Vector = Vectors.sparse(100,
+ (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))
+
+ intercept[IllegalArgumentException] {
+ model.approxNearestNeighbors(dataset, key, 0)
+ }
+ intercept[IllegalArgumentException] {
+ model.approxNearestNeighbors(dataset, key, -1)
+ }
+ }
+
+ test("approxSimilarityJoin for min hash on different dataset") {
val data1 = {
for (i <- 0 until 20) yield Vectors.sparse(100, (5 * i until 5 * i + 5).map((_, 1.0)))
}
@@ -113,8 +148,8 @@ class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with Default
}
val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
- val mh = new MinHash()
- .setOutputDim(20)
+ val mh = new MinHashLSH()
+ .setNumHashTables(20)
.setInputCol("keys")
.setOutputCol("values")
.setSeed(12345)