aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala38
1 files changed, 21 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 06ebb15869..3642e92865 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -71,20 +71,22 @@ sealed trait Vector extends Serializable {
}
/**
- * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros
- * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
+ * Returns a hash code value for the vector. The hash code is based on its size and its first 128
+ * nonzero entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
*/
override def hashCode(): Int = {
// This is a reference implementation. It calls return in foreachActive, which is slow.
// Subclasses should override it with optimized implementation.
var result: Int = 31 + size
+ var nnz = 0
this.foreachActive { (index, value) =>
- if (index < 16) {
+ if (nnz < Vectors.MAX_HASH_NNZ) {
// ignore explicit 0 for comparison between sparse and dense
if (value != 0) {
result = 31 * result + index
val bits = java.lang.Double.doubleToLongBits(value)
result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ nnz += 1
}
} else {
return result
@@ -536,6 +538,9 @@ object Vectors {
}
allEqual
}
+
+ /** Max number of nonzero entries used in computing hash code. */
+ private[linalg] val MAX_HASH_NNZ = 128
}
/**
@@ -578,13 +583,15 @@ class DenseVector @Since("1.0.0") (
override def hashCode(): Int = {
var result: Int = 31 + size
var i = 0
- val end = math.min(values.length, 16)
- while (i < end) {
+ val end = values.length
+ var nnz = 0
+ while (i < end && nnz < Vectors.MAX_HASH_NNZ) {
val v = values(i)
if (v != 0.0) {
result = 31 * result + i
val bits = java.lang.Double.doubleToLongBits(values(i))
result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ nnz += 1
}
i += 1
}
@@ -707,19 +714,16 @@ class SparseVector @Since("1.0.0") (
override def hashCode(): Int = {
var result: Int = 31 + size
val end = values.length
- var continue = true
var k = 0
- while ((k < end) & continue) {
- val i = indices(k)
- if (i < 16) {
- val v = values(k)
- if (v != 0.0) {
- result = 31 * result + i
- val bits = java.lang.Double.doubleToLongBits(v)
- result = 31 * result + (bits ^ (bits >>> 32)).toInt
- }
- } else {
- continue = false
+ var nnz = 0
+ while (k < end && nnz < Vectors.MAX_HASH_NNZ) {
+ val v = values(k)
+ if (v != 0.0) {
+ val i = indices(k)
+ result = 31 * result + i
+ val bits = java.lang.Double.doubleToLongBits(v)
+ result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ nnz += 1
}
k += 1
}