openModeller  Version 1.4.0
rf_alg.cpp
Go to the documentation of this file.
00001 
00027 #include "rf_alg.hh"
00028 
00029 #include <openmodeller/Sampler.hh>
00030 #include <openmodeller/Exceptions.hh>
00031 
00032 #include <string.h>
00033 #include <stdio.h>
00034 #include <stdlib.h>
00035 #include <sstream>
00036 
00037 #include "librf/instance_set.h"
00038 #include "librf/tree.h"
00039 #include "librf/weights.h"
00040 
00041 using namespace std;
00042 using namespace librf;
00043 
00044 /****************************************************************/
00045 /********************** Algorithm's Metadata ********************/
00046 
// Number of entries in the parameters table defined below.
#define NUM_PARAM 3

// Parameter identifiers. These are macros rather than constants because the
// log messages in this file are built by pasting them into adjacent string
// literals.
#define NUMTREES_ID   "NumTrees"
#define K_ID          "VarsPerTree"
#define UNSUP_ID      "ForceUnsupervisedLearning"

// Prefix prepended to the messages this algorithm logs.
#define RF_LOG_PREFIX "RfAlgorithm: "
00054 
00055 /******************************/
00056 /*** Algorithm's parameters ***/
00057 
00058 static AlgParamMetadata parameters[NUM_PARAM] = {
00059 
00060   // Number of Trees
00061   {
00062     NUMTREES_ID,        // Id.
00063     "Number of trees",  // Name.
00064     Integer,            // Type.
00065     "Number of trees",  // Overview
00066     "Number of trees",  // Description.
00067     1,         // Not zero if the parameter has lower limit.
00068     1,         // Parameter's lower limit.
00069     1,         // Not zero if the parameter has upper limit.
00070     1000,      // Parameter's upper limit.
00071     "10"       // Parameter's typical (default) value.
00072   },
00073   // Number of variables per tree
00074   {
00075     K_ID,     // Id.
00076     "Number of variables per tree", // Name.
00077     Integer,  // Type.
00078     "Number of variables per tree (zero defaults to the square root of the number of layers)", // Overview
00079     "Number of variables per tree (zero defaults to the square root of the number of layers)", // Description.
00080     0,         // Not zero if the parameter has lower limit.
00081     0,         // Parameter's lower limit.
00082     0,         // Not zero if the parameter has upper limit.
00083     0,         // Parameter's upper limit.
00084     "0"        // Parameter's typical (default) value.
00085   },
00086   // Force unsupervised learning
00087   {
00088     UNSUP_ID,     // Id.
00089     "Force unsupervised learning", // Name.
00090     Integer,  // Type.
00091     "Force unsupervised learning", // Overview
00092     "When absence points are provided, this parameter can be used to ignore them forcing unsupervised learning. Note that if no absences are provided, unsupervised learning will be used anyway.", // Description.
00093     1,         // Not zero if the parameter has lower limit.
00094     0,         // Parameter's lower limit.
00095     1,         // Not zero if the parameter has upper limit.
00096     1,         // Parameter's upper limit.
00097     "0"        // Parameter's typical (default) value.
00098   },
00099 };
00100 
00101 /************************************/
00102 /*** Algorithm's general metadata ***/
00103 
00104 static AlgMetadata metadata = {
00105 
00106   "RF",        // Id.
00107   "Random Forests",  // Name.
00108   "0.2",             // Version.
00109 
00110   // Overview
00111   "Random Forests",
00112 
00113   // Description.
00114   "Random Forests",
00115 
00116   "Leo Breiman & Adele Cutler", // Algorithm author.
00117   "Breiman, L. (2001). Random forests. Machine Learning, 45, 5-32.", // Bibliography.
00118 
00119   "Renato De Giovanni", // Code author.
00120   "renato [at] cria . org . br", // Code author's contact.
00121 
00122   0, // Does not accept categorical data.
00123   0, // Does not need (pseudo)absence points.
00124 
00125   NUM_PARAM, // Algorithm's parameters.
00126   parameters
00127 };
00128 
00129 /****************************************************************/
00130 /****************** Algorithm's factory function ****************/
00131 
00132 OM_ALG_DLL_EXPORT
00133 AlgorithmImpl *
00134 algorithmFactory()
00135 {
00136   return new RfAlgorithm();
00137 }
00138 
00139 OM_ALG_DLL_EXPORT
00140 AlgMetadata const *
00141 algorithmMetadata()
00142 {
00143   return &metadata;
00144 }
00145 
00146 
00147 /*********************************************/
/********** Random Forests algorithm *********/
00149 
00150 /*******************/
00151 /*** constructor ***/
00152 
00153 RfAlgorithm::RfAlgorithm() :
00154   AlgorithmImpl( &metadata ),
00155   _done( false ),
00156   _initialized( false )
00157 {
00158 }
00159 
00160 
00161 /******************/
00162 /*** destructor ***/
00163 
00164 RfAlgorithm::~RfAlgorithm()
00165 {
00166   if ( _initialized ) {
00167 
00168     delete _set;
00169   }
00170 
00171   for ( unsigned int i = 0; i < _trees.size(); ++i ) {
00172 
00173     delete _trees[i];
00174   }
00175 }
00176 
00177 /**************************/
00178 /*** need Normalization ***/
00179 int RfAlgorithm::needNormalization()
00180 {
00181   return 0;
00182 }
00183 
00184 /******************/
00185 /*** initialize ***/
00186 int
00187 RfAlgorithm::initialize()
00188 {
00189   int num_layers = _samp->numIndependent();
00190 
00191   // Number of trees
00192   if ( ! getParameter( NUMTREES_ID, &_num_trees ) ) {
00193 
00194     Log::instance()->error( RF_LOG_PREFIX "Parameter '" NUMTREES_ID "' not passed.\n" );
00195     return 0;
00196   }
00197 
00198   if ( _num_trees < 1 ) {
00199 
00200     Log::instance()->error( RF_LOG_PREFIX "Parameter '" NUMTREES_ID "' must be greater than zero.\n" );
00201     return 0;
00202   }
00203 
00204   // Don't allow people to use too much memory
00205   if ( _num_trees > 1000 ) {
00206 
00207     Log::instance()->error( RF_LOG_PREFIX "Parameter '" NUMTREES_ID "' is greater than 1000.\n" );
00208     return 0;
00209   }
00210 
00211   _trees.reserve( _num_trees );
00212 
00213   // Number of variables per tree
00214   if ( ! getParameter( K_ID, &_k ) ) {
00215 
00216     Log::instance()->error( RF_LOG_PREFIX "Parameter '" K_ID "' not passed.\n" );
00217     return 0;
00218   }
00219 
00220   if ( _k < 1 ) {
00221 
00222     _k = int( sqrt( double( num_layers ) ) );
00223   }
00224 
00225   // Unsupervised learning
00226   bool force_unsupervised_learning = false;
00227   int unsup;
00228   if ( getParameter( UNSUP_ID, &unsup ) && unsup == 1 ) {
00229 
00230     force_unsupervised_learning = true;
00231   }
00232 
00233   _class_weights.resize(2, 1);
00234 
00235   // Check the number of presences
00236   int num_presences = _samp->numPresence();
00237 
00238   if ( num_presences == 0 ) {
00239 
00240     Log::instance()->warn( RF_LOG_PREFIX "No presence points inside the mask!\n" );
00241     return 0;
00242   }
00243 
00244   // Load input points
00245 
00246   unsigned int seed = (unsigned int)_rand.get();
00247 
00248   stringstream sdata("");
00249   stringstream slabels("");
00250 
00251   OccurrencesImpl::const_iterator p_iterator;
00252   OccurrencesImpl::const_iterator p_end;
00253 
00254   OccurrencesPtr presences = _samp->getPresences();
00255 
00256   p_iterator = presences->begin();
00257   p_end = presences->end();
00258 
00259   while ( p_iterator != p_end ) {
00260 
00261     Sample presence = (*p_iterator)->environment();
00262 
00263     _sampleToLine( presence, sdata );
00264 
00265     slabels << "0" << endl; // presence
00266 
00267     ++p_iterator;
00268   }
00269 
00270   if ( _samp->numAbsence() && ! force_unsupervised_learning ) {
00271 
00272     OccurrencesPtr absences = _samp->getAbsences();
00273 
00274     p_iterator = absences->begin();
00275     p_end = absences->end();
00276 
00277     while ( p_iterator != p_end ) {
00278 
00279       Sample absence = (*p_iterator)->environment();
00280 
00281       _sampleToLine( absence, sdata );
00282 
00283       slabels << "1" << endl; // absence
00284 
00285       ++p_iterator;
00286     }
00287 
00288     istream data( sdata.rdbuf() );
00289     istream labels( slabels.rdbuf() );
00290 
00291     _set = InstanceSet::load_csv_and_labels( data, labels );
00292   }
00293   else {
00294 
00295     // Prepare for unsupervised learning
00296 
00297     istream data( sdata.rdbuf() );
00298 
00299     _set = InstanceSet::load_unsupervised( data, &seed );
00300   }
00301 
00302   _initialized = true;
00303 
00304   return 1;
00305 }
00306 
00307 
00308 /**********************/
00309 /*** sample to line ***/
00310 void
00311 RfAlgorithm::_sampleToLine( Sample sample, stringstream& ss ) const
00312 {
00313   for ( unsigned int j = 0; j < sample.size(); ++j ) {
00314 
00315     ss << sample[j] << ",";
00316   }
00317 
00318   ss << endl;
00319 }
00320 
00321 
00322 /***************/
00323 /*** iterate ***/
00324 int
00325 RfAlgorithm::iterate()
00326 {
00327   if ( (int)_trees.size() < _num_trees ) {
00328 
00329     weight_list* w = new weight_list( _set->size(), _set->size());
00330 
00331     // sample with replacement
00332     for ( unsigned int j = 0; j < _set->size(); ++j ) {
00333 
00334       int instance = _rand.get( 0, _set->size() - 1 );
00335       w->add( instance, _class_weights[_set->label(instance)] );
00336     }
00337 
00338     Tree* tree = new Tree( *_set, w, _k, 1, 0, _rand.get(1000) );
00339     tree->grow();
00340 
00341     _trees.push_back(tree);
00342   }
00343   else {
00344 
00345     _done = true;
00346   }
00347 
00348   return 1;
00349 }
00350 
00351 /********************/
00352 /*** get Progress ***/
00353 float RfAlgorithm::getProgress() const
00354 {
00355   if ( done() ) {
00356 
00357     return 1.0;
00358   }
00359 
00360   return (float)_trees.size() / (float)_num_trees;
00361 }
00362 
00363 
00364 /************/
00365 /*** done ***/
00366 int
00367 RfAlgorithm::done() const
00368 {
00369   return _done;
00370 }
00371 
00372 /*****************/
00373 /*** get Value ***/
00374 Scalar
00375 RfAlgorithm::getValue( const Sample& x ) const
00376 {
00377   stringstream sdata("");
00378 
00379   _sampleToLine( x, sdata );
00380 
00381   istream data( sdata.rdbuf() );
00382 
00383   stringstream slabels("0");
00384 
00385   istream labels( slabels.rdbuf() );
00386 
00387   InstanceSet* set = InstanceSet::load_csv_and_labels( data, labels );
00388 
00389   DiscreteDist votes;
00390 
00391   for ( unsigned int i = 0; i < _trees.size(); ++i ) {
00392 
00393     int predict = _trees[i]->predict( *set, 0 );
00394     votes.add( predict );
00395   }
00396 
00397   float prob = votes.percentage(0);
00398 
00399   delete set;
00400 
00401   return (double)prob;
00402 }
00403 
00404 /***********************/
00405 /*** get Convergence ***/
00406 int
00407 RfAlgorithm::getConvergence( Scalar * const val ) const
00408 {
00409   *val = 1.0;
00410   return 1;
00411 }
00412 
00413 /****************************************************************/
00414 /****************** configuration *******************************/
00415 void
00416 RfAlgorithm::_getConfiguration( ConfigurationPtr& config ) const
00417 {
00418   if ( ! _done )
00419     return;
00420 
00421   ConfigurationPtr model_config( new ConfigurationImpl("Rf") );
00422   config->addSubsection( model_config );
00423 
00424   model_config->addNameValue( "Trees", _num_trees );
00425   model_config->addNameValue( "K", _k );
00426 
00427   Tree* p_tree = NULL;
00428   tree_node* p_node = NULL;
00429 
00430   unsigned int num_nodes;
00431 
00432   char buffer[5];
00433 
00434   for ( int i=0; i < _num_trees; ++i ) {
00435 
00436     p_tree = _trees[i];
00437 
00438     ConfigurationPtr tree_config( new ConfigurationImpl("Tree") );
00439 
00440     num_nodes = p_tree->num_nodes();
00441 
00442     tree_config->addNameValue( "Nodes", (int)num_nodes );
00443 
00444     sprintf( buffer, "%4.2f", p_tree->training_accuracy() );
00445 
00446     tree_config->addNameValue( "Accuracy", buffer );
00447     tree_config->addNameValue( "Split", (int)p_tree->num_split_nodes() );
00448     tree_config->addNameValue( "Terminal", (int)p_tree->num_terminal_nodes() );
00449 
00450     for ( unsigned int j= 0; j < num_nodes; ++j ) {
00451     
00452       p_node = p_tree->get_node( j );    
00453 
00454       ConfigurationPtr node_config( new ConfigurationImpl("Node") );
00455 
00456       librf::NodeStatusType status = p_node->status;
00457 
00458       node_config->addNameValue( "Status", (int)status );
00459 
00460       if ( status == SPLIT ) {
00461 
00462         node_config->addNameValue( "L", (int)p_node->left );
00463         node_config->addNameValue( "R", (int)p_node->right );
00464         node_config->addNameValue( "A", (int)p_node->attr );
00465         node_config->addNameValue( "S", (float)p_node->split_point );
00466       }
00467       else if ( status == TERMINAL ) {
00468 
00469         node_config->addNameValue( "V", (char)p_node->label );
00470       }
00471 
00472       tree_config->addSubsection( node_config ); 
00473     }
00474 
00475     model_config->addSubsection( tree_config ); 
00476   }  
00477 }
00478 
00479 void
00480 RfAlgorithm::_setConfiguration( const ConstConfigurationPtr& config )
00481 {
00482   ConstConfigurationPtr model_config = config->getSubsection( "Rf", false );
00483 
00484   if ( ! model_config )
00485     return;
00486 
00487   _num_trees = model_config->getAttributeAsInt( "Trees", 0 );
00488 
00489   _k = model_config->getAttributeAsInt( "K", 0 );
00490 
00491   _trees.reserve( _num_trees );
00492 
00493   Configuration::subsection_list trees = model_config->getAllSubsections();
00494 
00495   Configuration::subsection_list::iterator tree = trees.begin();
00496   Configuration::subsection_list::iterator last_tree = trees.end();
00497 
00498   for ( ; tree != last_tree; ++tree ) {
00499 
00500     if ( (*tree)->getName() != "Tree" ) {
00501 
00502       continue;
00503     }
00504 
00505     Tree* my_tree = new Tree();
00506 
00507     Configuration::subsection_list nodes = (*tree)->getAllSubsections();
00508 
00509     Configuration::subsection_list::iterator node = nodes.begin();
00510     Configuration::subsection_list::iterator last_node = nodes.end();
00511 
00512     for ( ; node != last_node; ++node ) {
00513 
00514       if ( (*node)->getName() != "Node" ) {
00515 
00516         continue;
00517       }
00518 
00519       int status = (*node)->getAttributeAsInt( "Status", 0 );
00520 
00521       tree_node my_node;
00522       
00523       if ( status == SPLIT ) {
00524 
00525         my_node.status = SPLIT;
00526         my_node.left = (*node)->getAttributeAsInt( "L", 0 );
00527         my_node.right = (*node)->getAttributeAsInt( "R", 0 );
00528         my_node.attr = (*node)->getAttributeAsInt( "A", 0 );
00529         double split_point = (*node)->getAttributeAsDouble( "S", 0.0 );
00530         my_node.split_point = (float)split_point;
00531       }
00532       else if ( status == TERMINAL ) {
00533 
00534         my_node.status = TERMINAL;
00535         int label = (*node)->getAttributeAsInt( "V", 0 );
00536         my_node.label = uchar(label);
00537       }
00538       else {
00539 
00540         continue;
00541       }
00542 
00543       my_tree->add_node( my_node );
00544     }
00545 
00546     _trees.push_back( my_tree );
00547   }
00548 
00549   _initialized = true;
00550 
00551   _done = true;
00552 }