1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
|
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
/**
 * Tests for [[DecisionTreeRegressor]]. Training tests check the new spark.ml API against the
 * old spark.mllib implementation via [[DecisionTreeRegressorSuite.compareAPIs]].
 */
class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {

  import DecisionTreeRegressorSuite.compareAPIs

  /** Shared categorical dataset; populated in beforeAll() once the SparkContext `sc` is set up. */
  private var categoricalDataPointsRDD: RDD[LabeledPoint] = _

  override def beforeAll(): Unit = {
    // Must run first so that `sc` is available below.
    super.beforeAll()
    categoricalDataPointsRDD =
      sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
  }

  /**
   * Builds a regressor using variance impurity with the given limits, so each test only
   * states the parameters it actually varies.
   *
   * @param maxDepth maximum tree depth passed to setMaxDepth
   * @param maxBins maximum number of bins passed to setMaxBins
   */
  private def varianceRegressor(maxDepth: Int, maxBins: Int): DecisionTreeRegressor =
    new DecisionTreeRegressor()
      .setImpurity("variance")
      .setMaxDepth(maxDepth)
      .setMaxBins(maxBins)

  /////////////////////////////////////////////////////////////////////////////
  // Tests calling train()
  /////////////////////////////////////////////////////////////////////////////

  test("Regression stump with 3-ary (ordered) categorical features") {
    val dt = varianceRegressor(maxDepth = 2, maxBins = 100)
    val categoricalFeatures = Map(0 -> 3, 1 -> 3)
    compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
  }

  test("Regression stump with binary (ordered) categorical features") {
    val dt = varianceRegressor(maxDepth = 2, maxBins = 100)
    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
    compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
  }

  test("copied model must have the same parent") {
    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
    val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
    val model = varianceRegressor(maxDepth = 2, maxBins = 8).fit(df)
    MLTestingUtils.checkCopy(model)
  }

  /////////////////////////////////////////////////////////////////////////////
  // Tests of model save/load
  /////////////////////////////////////////////////////////////////////////////

  // TODO: test("model save/load") SPARK-6725
}
private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

  /**
   * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
   * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
   * Also checks that the new model reports the same number of features as the input data.
   *
   * @param data training points; passed as-is to the old API and wrapped in a DataFrame with
   *             categorical metadata for the new API
   * @param dt configured estimator; its settings are also translated into the old-API strategy
   * @param categoricalFeatures categorical feature index -> number of categories
   */
  def compareAPIs(
      data: RDD[LabeledPoint],
      dt: DecisionTreeRegressor,
      categoricalFeatures: Map[Int, Int]): Unit = {
    val numFeatures = data.first().features.size

    // Old API: derive the equivalent strategy from the new estimator's params.
    val oldStrategy = dt.getOldStrategy(categoricalFeatures)
    val oldTree = OldDecisionTree.train(data, oldStrategy)

    // New API: attach categorical metadata (numClasses = 0 marks regression) and fit.
    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
    val newTree = dt.fit(newData)

    // The parent estimator is not part of the equality check, so pass `dt` (the estimator that
    // produced newTree) directly rather than downcasting newTree.parent.
    val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, dt, categoricalFeatures)
    TreeTests.checkEqual(oldTreeAsNew, newTree)
    assert(newTree.numFeatures === numFeatures)
  }
}
|