aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala22
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala4
3 files changed, 22 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 9d895b8fac..5d11ed0971 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.tree
+import java.util.Objects
+
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
@@ -112,12 +114,15 @@ final class CategoricalSplit private[ml] (
}
}
- override def equals(o: Any): Boolean = {
- o match {
- case other: CategoricalSplit => featureIndex == other.featureIndex &&
- isLeft == other.isLeft && categories == other.categories
- case _ => false
- }
+ override def hashCode(): Int = {
+ val state = Seq(featureIndex, isLeft, categories)
+ state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+ }
+
+ override def equals(o: Any): Boolean = o match {
+ case other: CategoricalSplit => featureIndex == other.featureIndex &&
+ isLeft == other.isLeft && categories == other.categories
+ case _ => false
}
override private[tree] def toOld: OldSplit = {
@@ -181,6 +186,11 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
}
}
+ override def hashCode(): Int = {
+ val state = Seq(featureIndex, threshold)
+ state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+ }
+
override private[tree] def toOld: OldSplit = {
OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index bb5d6d9d51..90fa4fbbc6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -606,6 +606,8 @@ class SparseMatrix @Since("1.3.0") (
case _ => false
}
+ override def hashCode(): Int = toBreeze.hashCode
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 5ec83e8d5c..6e3da6b701 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -628,6 +628,8 @@ class DenseVector @Since("1.0.0") (
}
}
+ override def equals(other: Any): Boolean = super.equals(other)
+
override def hashCode(): Int = {
var result: Int = 31 + size
var i = 0
@@ -775,6 +777,8 @@ class SparseVector @Since("1.0.0") (
}
}
+ override def equals(other: Any): Boolean = super.equals(other)
+
override def hashCode(): Int = {
var result: Int = 31 + size
val end = values.length