author     Reynold Xin <rxin@databricks.com>      2015-03-20 14:13:02 -0400
committer  Xiangrui Meng <meng@databricks.com>    2015-03-20 14:13:02 -0400
commit     db4d317ccfdd9bd1dc7e8beac54ebcc35966b7d5 (patch)
tree       e6d6706f99b6564383b994750917858a0b1b1fc2 /mllib/src
parent     6f80c3e8880340597f161f87e64697bec86cc586 (diff)
[SPARK-6428][MLlib] Added explicit type for public methods and implemented hashCode when equals is defined.
I want to add a checker to turn public type checking on, since future pull requests can accidentally expose a non-public type. This is the first cleanup task.

Author: Reynold Xin <rxin@databricks.com>

Closes #5102 from rxin/mllib-hashcode-publicmethodtypes and squashes the following commits:

617f19e [Reynold Xin] Fixed Scala compilation error.
52bc2d5 [Reynold Xin] [MLlib] Added explicit type for public methods and implemented hashCode when equals is defined.
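The first half of the cleanup follows one mechanical rule: every public method gets an explicit result type, so type inference can never silently expose an unintended (or non-public) type in the API, and a style checker can later enforce the rule. A minimal sketch of the pattern, using a hypothetical class that is not part of the patch:

```scala
// Hypothetical Counter class, illustrating the cleanup pattern only.
class Counter(private var n: Int) {
  // Before the cleanup this would read `def increment(by: Int) = ...`,
  // letting the compiler infer (and silently expose) the result type.
  def increment(by: Int): this.type = {
    n += by
    this
  }

  // With an explicit annotation on every public member, a later change to
  // the body cannot widen or leak the return type without a compile error.
  def current: Int = n
}
```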
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala  18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala  12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala  10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala  9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala  35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala  6
21 files changed, 97 insertions, 58 deletions
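The second half of the change touches classes such as DenseMatrix, GridPartitioner, InformationGainStats, and Predict in the diff below, which overrode equals without hashCode, violating the contract that equal objects must have equal hash codes and breaking hash-based collections. The patch uses Guava's Objects.hashCode with primitives ascribed to their boxed types; a minimal standalone sketch of that pattern, with a hypothetical Point class and assuming Guava on the classpath (as it is in Spark):

```scala
import com.google.common.base.Objects

// Hypothetical value class: equals and hashCode must agree, or HashMap,
// HashSet, and anything keyed on this type will misbehave.
class Point(val x: Double, val y: Double) {
  override def equals(o: Any): Boolean = o match {
    case p: Point => x == p.x && y == p.y
    case _ => false
  }

  // Ascribing each Double as java.lang.Double boxes it so it fits the
  // varargs Objects.hashCode(Object...) signature, mirroring the
  // `gain: java.lang.Double` ascriptions in the diff below.
  override def hashCode: Int =
    Objects.hashCode(x: java.lang.Double, y: java.lang.Double)
}
```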
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 6131ba8832..fc4e12773c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -41,7 +41,7 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
def getNumFeatures: Int = get(numFeatures)
/** @group setParam */
- def setNumFeatures(value: Int) = set(numFeatures, value)
+ def setNumFeatures(value: Int): this.type = set(numFeatures, value)
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index cbd87ea8ae..15ca2547d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -345,9 +345,13 @@ private[python] class PythonMLLibAPI extends Serializable {
def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
predict(SerDe.asTupleRDD(userAndProducts.rdd))
- def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ def getUserFeatures: RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+ }
- def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ def getProductFeatures: RDD[Array[Any]] = {
+ SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+ }
}
@@ -909,7 +913,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for DenseVector
private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val vector: DenseVector = obj.asInstanceOf[DenseVector]
val bytes = new Array[Byte](8 * vector.size)
val bb = ByteBuffer.wrap(bytes)
@@ -941,7 +945,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for DenseMatrix
private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
val bytes = new Array[Byte](8 * m.values.size)
val order = ByteOrder.nativeOrder()
@@ -973,7 +977,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val v: SparseVector = obj.asInstanceOf[SparseVector]
val n = v.indices.size
val indiceBytes = new Array[Byte](4 * n)
@@ -1015,7 +1019,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for LabeledPoint
private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val point: LabeledPoint = obj.asInstanceOf[LabeledPoint]
saveObjects(out, pickler, point.label, point.features)
}
@@ -1031,7 +1035,7 @@ private[spark] object SerDe extends Serializable {
// Pickler for Rating
private[python] class RatingPickler extends BasePickler[Rating] {
- def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
val rating: Rating = obj.asInstanceOf[Rating]
saveObjects(out, pickler, rating.user, rating.product, rating.rating)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 2ebc7fa5d4..068449aa1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -83,10 +83,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
private object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Hard-code class name string in case it changes in the future */
- def thisClassName = "org.apache.spark.mllib.classification.NaiveBayesModel"
+ def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"
/** Model data for model import/export */
case class Data(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]])
@@ -174,7 +174,7 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
*
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
- def run(data: RDD[LabeledPoint]) = {
+ def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case SparseVector(size, indices, values) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 8956189ff1..3b6790cce4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -32,7 +32,7 @@ private[classification] object GLMClassificationModel {
object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Model data for import/export */
case class Data(weights: Vector, intercept: Double, threshold: Option[Double])
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index e41f941fd2..0f8d6a3996 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -536,5 +536,5 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable
def this(array: Array[Double]) = this(Vectors.dense(array))
/** Converts the vector to a dense vector. */
- def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index ea10bde5fa..a8378a76d2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -96,30 +96,30 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns precision for a given label (category)
* @param label the label.
*/
- def precision(label: Double) = {
+ def precision(label: Double): Double = {
val tp = tpPerClass(label)
val fp = fpPerClass.getOrElse(label, 0L)
- if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+ if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
}
/**
* Returns recall for a given label (category)
* @param label the label.
*/
- def recall(label: Double) = {
+ def recall(label: Double): Double = {
val tp = tpPerClass(label)
val fn = fnPerClass.getOrElse(label, 0L)
- if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+ if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
}
/**
* Returns f1-measure for a given label (category)
* @param label the label.
*/
- def f1Measure(label: Double) = {
+ def f1Measure(label: Double): Double = {
val p = precision(label)
val r = recall(label)
- if((p + r) == 0) 0 else 2 * p * r / (p + r)
+ if((p + r) == 0) 0.0 else 2 * p * r / (p + r)
}
private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
@@ -130,7 +130,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based precision
* (equals to micro-averaged document-based precision)
*/
- lazy val microPrecision = {
+ lazy val microPrecision: Double = {
val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
sumTp.toDouble / (sumTp + sumFp)
}
@@ -139,7 +139,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based recall
* (equals to micro-averaged document-based recall)
*/
- lazy val microRecall = {
+ lazy val microRecall: Double = {
val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
sumTp.toDouble / (sumTp + sumFn)
}
@@ -148,7 +148,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
* Returns micro-averaged label-based f1-measure
* (equals to micro-averaged document-based f1-measure)
*/
- lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+ lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
/**
* Returns the sequence of labels in ascending order
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0e4a4d0085..fdd8848189 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -146,12 +146,16 @@ class DenseMatrix(
def this(numRows: Int, numCols: Int, values: Array[Double]) =
this(numRows, numCols, values, false)
- override def equals(o: Any) = o match {
+ override def equals(o: Any): Boolean = o match {
case m: DenseMatrix =>
m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
case _ => false
}
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(numRows : Integer, numCols: Integer, toArray)
+ }
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BDM[Double](numRows, numCols, values)
@@ -173,7 +177,7 @@ class DenseMatrix(
values(index(i, j)) = v
}
- override def copy = new DenseMatrix(numRows, numCols, values.clone())
+ override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone())
private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f))
@@ -431,7 +435,9 @@ class SparseMatrix(
}
}
- override def copy = new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
+ override def copy: SparseMatrix = {
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone())
+ }
private[mllib] def map(f: Double => Double) =
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index e9d25dcb7e..2cda9b252e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -183,6 +183,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
}
}
+ override def hashCode: Int = 7919
+
private[spark] override def asNullable: VectorUDT = this
}
@@ -478,7 +480,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
- override def apply(i: Int) = values(i)
+ override def apply(i: Int): Double = values(i)
override def copy: DenseVector = {
new DenseVector(values.clone())
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 1d25396313..3323ae7b1f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -49,7 +49,7 @@ private[mllib] class GridPartitioner(
private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt
private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt
- override val numPartitions = rowPartitions * colPartitions
+ override val numPartitions: Int = rowPartitions * colPartitions
/**
* Returns the index of the partition the input coordinate belongs to.
@@ -85,6 +85,14 @@ private[mllib] class GridPartitioner(
false
}
}
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(
+ rows: java.lang.Integer,
+ cols: java.lang.Integer,
+ rowsPerPart: java.lang.Integer,
+ colsPerPart: java.lang.Integer)
+ }
}
private[mllib] object GridPartitioner {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index 405bae62ee..9349ecaa13 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -56,7 +56,7 @@ class UniformGenerator extends RandomDataGenerator[Double] {
random.nextDouble()
}
- override def setSeed(seed: Long) = random.setSeed(seed)
+ override def setSeed(seed: Long): Unit = random.setSeed(seed)
override def copy(): UniformGenerator = new UniformGenerator()
}
@@ -75,7 +75,7 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] {
random.nextGaussian()
}
- override def setSeed(seed: Long) = random.setSeed(seed)
+ override def setSeed(seed: Long): Unit = random.setSeed(seed)
override def copy(): StandardNormalGenerator = new StandardNormalGenerator()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index bd7e340ca2..b55944f74f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -32,7 +32,7 @@ private[regression] object GLMRegressionModel {
object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
/** Model data for model import/export */
case class Data(weights: Vector, intercept: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 8d5c36da32..ada227c200 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -83,10 +83,13 @@ class Strategy (
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- def isMulticlassClassification =
+ def isMulticlassClassification: Boolean = {
algo == Classification && numClasses > 2
- def isMulticlassWithCategoricalFeatures
- = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ }
+
+ def isMulticlassWithCategoricalFeatures: Boolean = {
+ isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+ }
/**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index b7950e0078..5ac10f3fd3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -71,7 +71,7 @@ object Entropy extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c946db9c0d..19d318203c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -67,7 +67,7 @@ object Gini extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index df9eafa5da..7104a7fa4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -58,7 +58,7 @@ object Variance extends Impurity {
* Get this impurity instance.
* This is useful for passing impurity parameters to a Strategy in Java.
*/
- def instance = this
+ def instance: this.type = this
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 8a57ebc387..c9bafd60fb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -120,10 +120,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
private[tree] object SaveLoadV1_0 {
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel"
case class PredictData(predict: Double, prob: Double) {
def toPredict: Predict = new Predict(predict, prob)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 80990aa9a6..f209fdafd3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,23 +38,32 @@ class InformationGainStats(
val leftPredict: Predict,
val rightPredict: Predict) extends Serializable {
- override def toString = {
+ override def toString: String = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
- override def equals(o: Any) =
- o match {
- case other: InformationGainStats => {
- gain == other.gain &&
- impurity == other.impurity &&
- leftImpurity == other.leftImpurity &&
- rightImpurity == other.rightImpurity &&
- leftPredict == other.leftPredict &&
- rightPredict == other.rightPredict
- }
- case _ => false
- }
+ override def equals(o: Any): Boolean = o match {
+ case other: InformationGainStats =>
+ gain == other.gain &&
+ impurity == other.impurity &&
+ leftImpurity == other.leftImpurity &&
+ rightImpurity == other.rightImpurity &&
+ leftPredict == other.leftPredict &&
+ rightPredict == other.rightPredict
+
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(
+ gain: java.lang.Double,
+ impurity: java.lang.Double,
+ leftImpurity: java.lang.Double,
+ rightImpurity: java.lang.Double,
+ leftPredict,
+ rightPredict)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index d961081d18..4f72bb8014 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -50,8 +50,10 @@ class Node (
var rightNode: Option[Node],
var stats: Option[InformationGainStats]) extends Serializable with Logging {
- override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "impurity = " + impurity + "split = " + split + ", stats = " + stats
+ override def toString: String = {
+ "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
+ "impurity = " + impurity + "split = " + split + ", stats = " + stats
+ }
/**
* build the left node and right nodes if not leaf
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index ad4c0dbbfb..25990af7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -29,7 +29,7 @@ class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
- override def toString = {
+ override def toString: String = {
"predict = %f, prob = %f".format(predict, prob)
}
@@ -39,4 +39,8 @@ class Predict(
case _ => false
}
}
+
+ override def hashCode: Int = {
+ com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index b7a85f5854..fb35e70a8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -38,9 +38,10 @@ case class Split(
featureType: FeatureType,
categories: List[Double]) {
- override def toString =
+ override def toString: String = {
"Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
", categories = " + categories
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 30a8f7ca30..f160852c69 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -79,7 +79,7 @@ object RandomForestModel extends Loader[RandomForestModel] {
private object SaveLoadV1_0 {
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.model.RandomForestModel"
}
}
@@ -130,7 +130,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
private object SaveLoadV1_0 {
// Hard-code class name string in case it changes in the future
- def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+ def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
}
}
@@ -257,7 +257,7 @@ private[tree] object TreeEnsembleModel extends Logging {
import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
- def thisFormatVersion = "1.0"
+ def thisFormatVersion: String = "1.0"
case class Metadata(
algo: String,