/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
  DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

/**
 * Tests for [[DecisionTreeRegressor]].
 *
 * Each test configures a regressor and hands it to
 * `DecisionTreeRegressorSuite.compareAPIs`, which trains a tree through both the old and the
 * new API on the same data and checks that the two results are equal.
 */
class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {

  import DecisionTreeRegressorSuite.compareAPIs

  // Shared fixture; assigned in beforeAll() once the SparkContext `sc` is available.
  private var categoricalDataPointsRDD: RDD[LabeledPoint] = _

  /** Builds the shared RDD of categorical data points after the parent set-up has run. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    val points = OldDecisionTreeSuite.generateCategoricalDataPoints()
    categoricalDataPointsRDD = sc.parallelize(points)
  }

  /////////////////////////////////////////////////////////////////////////////
  // Tests calling train()
  /////////////////////////////////////////////////////////////////////////////

  test("Regression stump with 3-ary (ordered) categorical features") {
    // Features 0 and 1 are both declared categorical with 3 categories.
    val featureArities = Map(0 -> 3, 1 -> 3)
    val regressor = new DecisionTreeRegressor()
      .setImpurity("variance")
      .setMaxDepth(2)
      .setMaxBins(100)
    compareAPIs(categoricalDataPointsRDD, regressor, featureArities)
  }

  test("Regression stump with binary (ordered) categorical features") {
    // Features 0 and 1 are both declared categorical with 2 categories.
    val featureArities = Map(0 -> 2, 1 -> 2)
    val regressor = new DecisionTreeRegressor()
      .setImpurity("variance")
      .setMaxDepth(2)
      .setMaxBins(100)
    compareAPIs(categoricalDataPointsRDD, regressor, featureArities)
  }

  /////////////////////////////////////////////////////////////////////////////
  // Tests of model save/load
  /////////////////////////////////////////////////////////////////////////////

  // TODO: test("model save/load")  SPARK-6725
}

private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

  /**
   * Fits one decision tree with the old API and one with the new API on the same dataset,
   * converts the old-API tree into the new model format, and checks the two for equality via
   * `TreeTests.checkEqual`.
   *
   * @param data                training data, used unchanged by the old API and wrapped into a
   *                            DataFrame with metadata for the new API
   * @param dt                  regressor supplying the old-API strategy and performing the
   *                            new-API fit
   * @param categoricalFeatures map passed to both APIs describing the categorical features
   *                            (presumably feature index -> number of categories)
   */
  def compareAPIs(
      data: RDD[LabeledPoint],
      dt: DecisionTreeRegressor,
      categoricalFeatures: Map[Int, Int]): Unit = {
    // Old API: derive the strategy from the regressor's settings and train directly on the RDD.
    val legacyStrategy = dt.getOldStrategy(categoricalFeatures)
    val legacyModel = OldDecisionTree.train(data, legacyStrategy)

    // New API: attach metadata to the data, then fit.
    val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
    val fittedModel = dt.fit(newData)

    // Use parent, fittingParamMap from the new model since these are not checked anyways.
    val parentEstimator = fittedModel.parent.asInstanceOf[DecisionTreeRegressor]
    val convertedModel =
      DecisionTreeRegressionModel.fromOld(legacyModel, parentEstimator, categoricalFeatures)

    TreeTests.checkEqual(convertedModel, fittedModel)
  }
}