/*
 * openModeller — Version 1.4.0
 * Random Forests algorithm plugin (rf_alg.cpp).
 */
00001 00027 #include "rf_alg.hh" 00028 00029 #include <openmodeller/Sampler.hh> 00030 #include <openmodeller/Exceptions.hh> 00031 00032 #include <string.h> 00033 #include <stdio.h> 00034 #include <stdlib.h> 00035 #include <sstream> 00036 00037 #include "librf/instance_set.h" 00038 #include "librf/tree.h" 00039 #include "librf/weights.h" 00040 00041 using namespace std; 00042 using namespace librf; 00043 00044 /****************************************************************/ 00045 /********************** Algorithm's Metadata ********************/ 00046 00047 #define NUM_PARAM 3 00048 00049 #define NUMTREES_ID "NumTrees" 00050 #define K_ID "VarsPerTree" 00051 #define UNSUP_ID "ForceUnsupervisedLearning" 00052 00053 #define RF_LOG_PREFIX "RfAlgorithm: " 00054 00055 /******************************/ 00056 /*** Algorithm's parameters ***/ 00057 00058 static AlgParamMetadata parameters[NUM_PARAM] = { 00059 00060 // Number of Trees 00061 { 00062 NUMTREES_ID, // Id. 00063 "Number of trees", // Name. 00064 Integer, // Type. 00065 "Number of trees", // Overview 00066 "Number of trees", // Description. 00067 1, // Not zero if the parameter has lower limit. 00068 1, // Parameter's lower limit. 00069 1, // Not zero if the parameter has upper limit. 00070 1000, // Parameter's upper limit. 00071 "10" // Parameter's typical (default) value. 00072 }, 00073 // Number of variables per tree 00074 { 00075 K_ID, // Id. 00076 "Number of variables per tree", // Name. 00077 Integer, // Type. 00078 "Number of variables per tree (zero defaults to the square root of the number of layers)", // Overview 00079 "Number of variables per tree (zero defaults to the square root of the number of layers)", // Description. 00080 0, // Not zero if the parameter has lower limit. 00081 0, // Parameter's lower limit. 00082 0, // Not zero if the parameter has upper limit. 00083 0, // Parameter's upper limit. 00084 "0" // Parameter's typical (default) value. 
00085 }, 00086 // Force unsupervised learning 00087 { 00088 UNSUP_ID, // Id. 00089 "Force unsupervised learning", // Name. 00090 Integer, // Type. 00091 "Force unsupervised learning", // Overview 00092 "When absence points are provided, this parameter can be used to ignore them forcing unsupervised learning. Note that if no absences are provided, unsupervised learning will be used anyway.", // Description. 00093 1, // Not zero if the parameter has lower limit. 00094 0, // Parameter's lower limit. 00095 1, // Not zero if the parameter has upper limit. 00096 1, // Parameter's upper limit. 00097 "0" // Parameter's typical (default) value. 00098 }, 00099 }; 00100 00101 /************************************/ 00102 /*** Algorithm's general metadata ***/ 00103 00104 static AlgMetadata metadata = { 00105 00106 "RF", // Id. 00107 "Random Forests", // Name. 00108 "0.2", // Version. 00109 00110 // Overview 00111 "Random Forests", 00112 00113 // Description. 00114 "Random Forests", 00115 00116 "Leo Breiman & Adele Cutler", // Algorithm author. 00117 "Breiman, L. (2001). Random forests. Machine Learning, 45, 5-32.", // Bibliography. 00118 00119 "Renato De Giovanni", // Code author. 00120 "renato [at] cria . org . br", // Code author's contact. 00121 00122 0, // Does not accept categorical data. 00123 0, // Does not need (pseudo)absence points. 00124 00125 NUM_PARAM, // Algorithm's parameters. 
00126 parameters 00127 }; 00128 00129 /****************************************************************/ 00130 /****************** Algorithm's factory function ****************/ 00131 00132 OM_ALG_DLL_EXPORT 00133 AlgorithmImpl * 00134 algorithmFactory() 00135 { 00136 return new RfAlgorithm(); 00137 } 00138 00139 OM_ALG_DLL_EXPORT 00140 AlgMetadata const * 00141 algorithmMetadata() 00142 { 00143 return &metadata; 00144 } 00145 00146 00147 /*********************************************/ 00148 /************** SVM algorithm ****************/ 00149 00150 /*******************/ 00151 /*** constructor ***/ 00152 00153 RfAlgorithm::RfAlgorithm() : 00154 AlgorithmImpl( &metadata ), 00155 _done( false ), 00156 _initialized( false ) 00157 { 00158 } 00159 00160 00161 /******************/ 00162 /*** destructor ***/ 00163 00164 RfAlgorithm::~RfAlgorithm() 00165 { 00166 if ( _initialized ) { 00167 00168 delete _set; 00169 } 00170 00171 for ( unsigned int i = 0; i < _trees.size(); ++i ) { 00172 00173 delete _trees[i]; 00174 } 00175 } 00176 00177 /**************************/ 00178 /*** need Normalization ***/ 00179 int RfAlgorithm::needNormalization() 00180 { 00181 return 0; 00182 } 00183 00184 /******************/ 00185 /*** initialize ***/ 00186 int 00187 RfAlgorithm::initialize() 00188 { 00189 int num_layers = _samp->numIndependent(); 00190 00191 // Number of trees 00192 if ( ! 
getParameter( NUMTREES_ID, &_num_trees ) ) { 00193 00194 Log::instance()->error( RF_LOG_PREFIX "Parameter '" NUMTREES_ID "' not passed.\n" ); 00195 return 0; 00196 } 00197 00198 if ( _num_trees < 1 ) { 00199 00200 Log::instance()->error( RF_LOG_PREFIX "Parameter '" NUMTREES_ID "' must be greater than zero.\n" ); 00201 return 0; 00202 } 00203 00204 // Don't allow people to use too much memory 00205 if ( _num_trees > 1000 ) { 00206 00207 Log::instance()->error( RF_LOG_PREFIX "Parameter '" NUMTREES_ID "' is greater than 1000.\n" ); 00208 return 0; 00209 } 00210 00211 _trees.reserve( _num_trees ); 00212 00213 // Number of variables per tree 00214 if ( ! getParameter( K_ID, &_k ) ) { 00215 00216 Log::instance()->error( RF_LOG_PREFIX "Parameter '" K_ID "' not passed.\n" ); 00217 return 0; 00218 } 00219 00220 if ( _k < 1 ) { 00221 00222 _k = int( sqrt( double( num_layers ) ) ); 00223 } 00224 00225 // Unsupervised learning 00226 bool force_unsupervised_learning = false; 00227 int unsup; 00228 if ( getParameter( UNSUP_ID, &unsup ) && unsup == 1 ) { 00229 00230 force_unsupervised_learning = true; 00231 } 00232 00233 _class_weights.resize(2, 1); 00234 00235 // Check the number of presences 00236 int num_presences = _samp->numPresence(); 00237 00238 if ( num_presences == 0 ) { 00239 00240 Log::instance()->warn( RF_LOG_PREFIX "No presence points inside the mask!\n" ); 00241 return 0; 00242 } 00243 00244 // Load input points 00245 00246 unsigned int seed = (unsigned int)_rand.get(); 00247 00248 stringstream sdata(""); 00249 stringstream slabels(""); 00250 00251 OccurrencesImpl::const_iterator p_iterator; 00252 OccurrencesImpl::const_iterator p_end; 00253 00254 OccurrencesPtr presences = _samp->getPresences(); 00255 00256 p_iterator = presences->begin(); 00257 p_end = presences->end(); 00258 00259 while ( p_iterator != p_end ) { 00260 00261 Sample presence = (*p_iterator)->environment(); 00262 00263 _sampleToLine( presence, sdata ); 00264 00265 slabels << "0" << endl; // presence 
00266 00267 ++p_iterator; 00268 } 00269 00270 if ( _samp->numAbsence() && ! force_unsupervised_learning ) { 00271 00272 OccurrencesPtr absences = _samp->getAbsences(); 00273 00274 p_iterator = absences->begin(); 00275 p_end = absences->end(); 00276 00277 while ( p_iterator != p_end ) { 00278 00279 Sample absence = (*p_iterator)->environment(); 00280 00281 _sampleToLine( absence, sdata ); 00282 00283 slabels << "1" << endl; // absence 00284 00285 ++p_iterator; 00286 } 00287 00288 istream data( sdata.rdbuf() ); 00289 istream labels( slabels.rdbuf() ); 00290 00291 _set = InstanceSet::load_csv_and_labels( data, labels ); 00292 } 00293 else { 00294 00295 // Prepare for unsupervised learning 00296 00297 istream data( sdata.rdbuf() ); 00298 00299 _set = InstanceSet::load_unsupervised( data, &seed ); 00300 } 00301 00302 _initialized = true; 00303 00304 return 1; 00305 } 00306 00307 00308 /**********************/ 00309 /*** sample to line ***/ 00310 void 00311 RfAlgorithm::_sampleToLine( Sample sample, stringstream& ss ) const 00312 { 00313 for ( unsigned int j = 0; j < sample.size(); ++j ) { 00314 00315 ss << sample[j] << ","; 00316 } 00317 00318 ss << endl; 00319 } 00320 00321 00322 /***************/ 00323 /*** iterate ***/ 00324 int 00325 RfAlgorithm::iterate() 00326 { 00327 if ( (int)_trees.size() < _num_trees ) { 00328 00329 weight_list* w = new weight_list( _set->size(), _set->size()); 00330 00331 // sample with replacement 00332 for ( unsigned int j = 0; j < _set->size(); ++j ) { 00333 00334 int instance = _rand.get( 0, _set->size() - 1 ); 00335 w->add( instance, _class_weights[_set->label(instance)] ); 00336 } 00337 00338 Tree* tree = new Tree( *_set, w, _k, 1, 0, _rand.get(1000) ); 00339 tree->grow(); 00340 00341 _trees.push_back(tree); 00342 } 00343 else { 00344 00345 _done = true; 00346 } 00347 00348 return 1; 00349 } 00350 00351 /********************/ 00352 /*** get Progress ***/ 00353 float RfAlgorithm::getProgress() const 00354 { 00355 if ( done() ) { 00356 
00357 return 1.0; 00358 } 00359 00360 return (float)_trees.size() / (float)_num_trees; 00361 } 00362 00363 00364 /************/ 00365 /*** done ***/ 00366 int 00367 RfAlgorithm::done() const 00368 { 00369 return _done; 00370 } 00371 00372 /*****************/ 00373 /*** get Value ***/ 00374 Scalar 00375 RfAlgorithm::getValue( const Sample& x ) const 00376 { 00377 stringstream sdata(""); 00378 00379 _sampleToLine( x, sdata ); 00380 00381 istream data( sdata.rdbuf() ); 00382 00383 stringstream slabels("0"); 00384 00385 istream labels( slabels.rdbuf() ); 00386 00387 InstanceSet* set = InstanceSet::load_csv_and_labels( data, labels ); 00388 00389 DiscreteDist votes; 00390 00391 for ( unsigned int i = 0; i < _trees.size(); ++i ) { 00392 00393 int predict = _trees[i]->predict( *set, 0 ); 00394 votes.add( predict ); 00395 } 00396 00397 float prob = votes.percentage(0); 00398 00399 delete set; 00400 00401 return (double)prob; 00402 } 00403 00404 /***********************/ 00405 /*** get Convergence ***/ 00406 int 00407 RfAlgorithm::getConvergence( Scalar * const val ) const 00408 { 00409 *val = 1.0; 00410 return 1; 00411 } 00412 00413 /****************************************************************/ 00414 /****************** configuration *******************************/ 00415 void 00416 RfAlgorithm::_getConfiguration( ConfigurationPtr& config ) const 00417 { 00418 if ( ! 
_done ) 00419 return; 00420 00421 ConfigurationPtr model_config( new ConfigurationImpl("Rf") ); 00422 config->addSubsection( model_config ); 00423 00424 model_config->addNameValue( "Trees", _num_trees ); 00425 model_config->addNameValue( "K", _k ); 00426 00427 Tree* p_tree = NULL; 00428 tree_node* p_node = NULL; 00429 00430 unsigned int num_nodes; 00431 00432 char buffer[5]; 00433 00434 for ( int i=0; i < _num_trees; ++i ) { 00435 00436 p_tree = _trees[i]; 00437 00438 ConfigurationPtr tree_config( new ConfigurationImpl("Tree") ); 00439 00440 num_nodes = p_tree->num_nodes(); 00441 00442 tree_config->addNameValue( "Nodes", (int)num_nodes ); 00443 00444 sprintf( buffer, "%4.2f", p_tree->training_accuracy() ); 00445 00446 tree_config->addNameValue( "Accuracy", buffer ); 00447 tree_config->addNameValue( "Split", (int)p_tree->num_split_nodes() ); 00448 tree_config->addNameValue( "Terminal", (int)p_tree->num_terminal_nodes() ); 00449 00450 for ( unsigned int j= 0; j < num_nodes; ++j ) { 00451 00452 p_node = p_tree->get_node( j ); 00453 00454 ConfigurationPtr node_config( new ConfigurationImpl("Node") ); 00455 00456 librf::NodeStatusType status = p_node->status; 00457 00458 node_config->addNameValue( "Status", (int)status ); 00459 00460 if ( status == SPLIT ) { 00461 00462 node_config->addNameValue( "L", (int)p_node->left ); 00463 node_config->addNameValue( "R", (int)p_node->right ); 00464 node_config->addNameValue( "A", (int)p_node->attr ); 00465 node_config->addNameValue( "S", (float)p_node->split_point ); 00466 } 00467 else if ( status == TERMINAL ) { 00468 00469 node_config->addNameValue( "V", (char)p_node->label ); 00470 } 00471 00472 tree_config->addSubsection( node_config ); 00473 } 00474 00475 model_config->addSubsection( tree_config ); 00476 } 00477 } 00478 00479 void 00480 RfAlgorithm::_setConfiguration( const ConstConfigurationPtr& config ) 00481 { 00482 ConstConfigurationPtr model_config = config->getSubsection( "Rf", false ); 00483 00484 if ( ! 
model_config ) 00485 return; 00486 00487 _num_trees = model_config->getAttributeAsInt( "Trees", 0 ); 00488 00489 _k = model_config->getAttributeAsInt( "K", 0 ); 00490 00491 _trees.reserve( _num_trees ); 00492 00493 Configuration::subsection_list trees = model_config->getAllSubsections(); 00494 00495 Configuration::subsection_list::iterator tree = trees.begin(); 00496 Configuration::subsection_list::iterator last_tree = trees.end(); 00497 00498 for ( ; tree != last_tree; ++tree ) { 00499 00500 if ( (*tree)->getName() != "Tree" ) { 00501 00502 continue; 00503 } 00504 00505 Tree* my_tree = new Tree(); 00506 00507 Configuration::subsection_list nodes = (*tree)->getAllSubsections(); 00508 00509 Configuration::subsection_list::iterator node = nodes.begin(); 00510 Configuration::subsection_list::iterator last_node = nodes.end(); 00511 00512 for ( ; node != last_node; ++node ) { 00513 00514 if ( (*node)->getName() != "Node" ) { 00515 00516 continue; 00517 } 00518 00519 int status = (*node)->getAttributeAsInt( "Status", 0 ); 00520 00521 tree_node my_node; 00522 00523 if ( status == SPLIT ) { 00524 00525 my_node.status = SPLIT; 00526 my_node.left = (*node)->getAttributeAsInt( "L", 0 ); 00527 my_node.right = (*node)->getAttributeAsInt( "R", 0 ); 00528 my_node.attr = (*node)->getAttributeAsInt( "A", 0 ); 00529 double split_point = (*node)->getAttributeAsDouble( "S", 0.0 ); 00530 my_node.split_point = (float)split_point; 00531 } 00532 else if ( status == TERMINAL ) { 00533 00534 my_node.status = TERMINAL; 00535 int label = (*node)->getAttributeAsInt( "V", 0 ); 00536 my_node.label = uchar(label); 00537 } 00538 else { 00539 00540 continue; 00541 } 00542 00543 my_tree->add_node( my_node ); 00544 } 00545 00546 _trees.push_back( my_tree ); 00547 } 00548 00549 _initialized = true; 00550 00551 _done = true; 00552 }