diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 88 |
1 files changed, 67 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index af0cfe22ca..34833e90d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -317,7 +325,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +379,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +409,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +430,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +459,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } } object DenseVector { @@ -522,8 +546,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -547,7 +571,7 @@ class SparseVector( private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +580,28 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } } object SparseVector { |