openModeller  Version 1.4.0
om_test.cpp
Go to the documentation of this file.
00001 #include <openmodeller/om.hh>
00002 #include <openmodeller/Exceptions.hh>
00003 #include <openmodeller/os_specific.hh>
00004 
00005 #include "getopts/getopts.h"
00006 
00007 #include "om_cmd_utils.hh"
00008 
00009 #include <fstream>   // file I/O for XML
00010 #include <sstream>   // ostringstream datatype
00011 #include <stdio.h>   // file I/O for log
00012 #include <time.h>    // used to limit the number of times that the progress is written to a file
00013 #include <string>    // string library
00014 #include <stdexcept> // try/catch
00015 
00016 using namespace std;
00017 
00019 int main( int argc, char **argv ) {
00020 
00021   Options opts;
00022   int option;
00023 
00024   // command-line parameters (short name, long name, description, take args)
00025   opts.addOption( "v", "version"       , "Display version info"                        , false );
00026   opts.addOption( "r", "xml-req"       , "(option 1) Test request file in XML"         , true );
00027   opts.addOption( "o", "model"         , "(option 2) Serialized model file"            , true );
00028   opts.addOption( "p", "points"        , "(option 2) TAB-delimited file with points"   , true );
00029   opts.addOption( "" , "calc-matrix"   , "Calculate confusion matrix"                  , false );
00030   opts.addOption( "t", "threshold"     , "Confusion matrix threshold"                  , true );
00031   opts.addOption( "" , "ignore-abs"    , "Ignore absences for the confusion matrix"    , false );
00032   opts.addOption( "" , "calc-roc"      , "Calculate ROC curve"                         , false );
00033   opts.addOption( "n", "resolution"    , "Number of points in the ROC curve"           , true );
00034   opts.addOption( "b", "num-background", "Number of background points for the ROC curve when there are no absences", true );
00035   opts.addOption( "e", "max-omission"  , "Calculate ROC partial area ratio given the maximum omission", true );
00036   opts.addOption( "" , "abs-background", "Use absences as background in ROC curve"   , false );
00037   opts.addOption( "s", "result"      , "File to store test result in XML"            , true );
00038   opts.addOption( "", "log-level"    , "Set the log level (debug, warn, info, error)", true );
00039   opts.addOption( "", "log-file"     , "Log file"                                    , true );
00040   opts.addOption( "" , "prog-file"   , "File to store test progress"                 , true );
00041   opts.addOption( "c", "config-file" , "Configuration file for openModeller"         , true );
00042 
00043   std::string log_level("info");
00044   std::string request_file;
00045   std::string model_file;
00046   std::string points_file;
00047   bool calc_matrix = false;
00048   std::string threshold_string("");
00049   double threshold = CONF_MATRIX_DEFAULT_THRESHOLD;
00050   bool ignore_abs = false;
00051   bool calc_roc = false;
00052   std::string resolution_string("");
00053   int resolution = -1;
00054   std::string num_background_string("");
00055   int num_background = -1;
00056   std::string max_omission_string("");
00057   double max_omission = 1.0;
00058   bool abs_background = false;
00059   std::string result_file;
00060   std::string log_file;
00061   std::string progress_file;
00062   std::string config_file;
00063 
00064   if ( ! opts.parse( argc, argv ) ) {
00065 
00066     opts.showHelp( argv[0] );
00067   }
00068 
00069   // Set up any related external resources
00070   setupExternalResources();
00071 
00072   OpenModeller om;
00073 
00074   while ( ( option = opts.cycle() ) >= 0 ) {
00075 
00076     switch ( option ) {
00077 
00078       case 0:
00079         printf( "om_test %s\n", om.getVersion().c_str() );
00080         printf("This is free software; see the source for copying conditions. There is NO\n");
00081         printf("warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n");
00082         exit(0);
00083         break;
00084       case 1:
00085         request_file = opts.getArgs( option );
00086         break;
00087       case 2:
00088         model_file = opts.getArgs( option );
00089         break;
00090       case 3:
00091         points_file = opts.getArgs( option );
00092         break;
00093       case 4:
00094         calc_matrix = true;
00095         break;
00096       case 5:
00097         threshold_string = opts.getArgs( option );
00098         break;
00099       case 6:
00100         ignore_abs = true;
00101         break;
00102       case 7:
00103         calc_roc = true;
00104         break;
00105       case 8:
00106         resolution_string = opts.getArgs( option );
00107         break;
00108       case 9:
00109         num_background_string = opts.getArgs( option );
00110         break;
00111       case 10:
00112         max_omission_string = opts.getArgs( option );
00113         break;
00114       case 11:
00115         abs_background = true;
00116         break;
00117       case 12:
00118         result_file = opts.getArgs( option );
00119         break;
00120       case 13:
00121         log_level = opts.getArgs( option );
00122         break;
00123       case 14:
00124         log_file = opts.getArgs( option );
00125         break;
00126       case 15:
00127         progress_file = opts.getArgs( option );
00128         break;
00129       case 16:
00130         config_file = opts.getArgs( option );
00131         break;
00132       default:
00133         break;
00134     }
00135   }
00136 
00137   // om configuration
00138   if ( ! config_file.empty() ) { 
00139 
00140     Settings::loadConfig( config_file );
00141   }
00142 
00143   // Initialize progress data if user wants to track progress
00144   progress_data prog_data;
00145 
00146   if ( ! progress_file.empty() ) { 
00147 
00148     prog_data.file_name = progress_file;
00149 
00150     time( &prog_data.timestamp );
00151 
00152     prog_data.progress = -1.0; // queued
00153 
00154     // Always create initial file with progress 0
00155     progressFileCallback( 0.0, &prog_data );
00156   }
00157 
00158   // Log stuff
00159 
00160   Log::Level level_code = getLogLevel( log_level );
00161 
00162   if ( ! log_file.empty() ) {
00163 
00164     Log::instance()->set( level_code, log_file, "" );
00165   }
00166   else {
00167  
00168     // Just set the level - things will go to stderr
00169     Log::instance()->setLevel( level_code );
00170   }
00171 
00172   // Check parameters
00173 
00174   if ( ! request_file.empty() ) {
00175 
00176     if ( ! model_file.empty() ) {
00177 
00178       Log::instance()->warn( "Model file parameter will be ignored (using XML request instead)\n" );
00179     }
00180     if ( ! points_file.empty() ) {
00181 
00182       printf( "Points file parameter will be ignored (using XML request instead)\n");
00183     }
00184     if ( calc_roc ) {
00185 
00186       printf( "Parameter to calculate ROC curve will be ignored (when using XML request you should specify it in the XML)\n");
00187     }
00188     if ( calc_matrix ) {
00189 
00190       printf( "Parameter to calculate confusion matrix will be ignored (when using XML request you should specify it in the XML)\n");
00191     }
00192   }
00193   else if ( ( ! model_file.empty() ) && ! points_file.empty() ) {
00194 
00195     // Custom threshold
00196     if ( ! threshold_string.empty() ) {
00197 
00198       threshold = atof( threshold_string.c_str() );
00199     }
00200 
00201     // Custom resolution
00202     if ( ! resolution_string.empty() ) {
00203 
00204       resolution = atoi( resolution_string.c_str() );
00205     }
00206 
00207     // Custom number of background points
00208     if ( ! num_background_string.empty() ) {
00209 
00210       num_background = atoi( num_background_string.c_str() );
00211     }
00212 
00213     // Custom max omission
00214     if ( ! max_omission_string.empty() ) {
00215 
00216       max_omission = atof( max_omission_string.c_str() );
00217     }
00218   }
00219   else {
00220 
00221     printf( "Please specify either a test request file in XML or a serialized model and a TAB-delimited file with the points to be tested\n");
00222 
00223     // If user is tracking progress
00224     if ( ! progress_file.empty() ) { 
00225 
00226       // -2 means aborted
00227       progressFileCallback( -2.0, &prog_data );
00228     }
00229 
00230     exit(-1);
00231   }
00232 
00233   if ( ! calc_matrix ) {
00234 
00235     if ( ! threshold_string.empty() ) {
00236 
00237       Log::instance()->warn( "Ignoring threshold - option only available with confusion matrix\n" );
00238     }
00239     if ( ignore_abs ) {
00240 
00241       Log::instance()->warn( "Ignoring ignore-abs - option only available with confusion matrix\n" );
00242     }
00243   }
00244 
00245   if ( ! calc_roc ) {
00246 
00247     if ( ! resolution_string.empty() ) {
00248 
00249       Log::instance()->warn( "Ignoring resolution - option only available with ROC curve\n" );
00250     }
00251     if ( ! max_omission_string.empty() ) {
00252 
00253       Log::instance()->warn( "Ignoring maximum omission - option only available with ROC curve\n" );
00254     }
00255     if ( abs_background ) {
00256 
00257       Log::instance()->warn( "Ignoring abs-background - option only available with ROC curve\n" );
00258     }
00259     if ( ! num_background_string.empty() ) {
00260 
00261       Log::instance()->warn( "Ignoring number of background points - option only available with ROC curve\n" );
00262     }
00263   }
00264 
00265   // Real work
00266 
00267   try {
00268 
00269     // Load algorithms and instantiate controller class
00270     AlgorithmFactory::searchDefaultDirs();
00271 
00272     SamplerPtr sampler;
00273 
00274     AlgorithmPtr alg;
00275 
00276     if ( ! request_file.empty() ) {
00277 
00278       // Loading input from XML request
00279 
00280       ConfigurationPtr input = Configuration::readXml( request_file.c_str() );
00281 
00282       alg = AlgorithmFactory::newAlgorithm( input->getSubsection( "Algorithm" ) );
00283 
00284       sampler = createSampler( input->getSubsection( "Sampler" ) );
00285 
00286       try {
00287 
00288         ConfigurationPtr statistics_param = input->getSubsection( "Statistics" );
00289 
00290         try {
00291 
00292           ConfigurationPtr matrix_param = statistics_param->getSubsection( "ConfusionMatrix" );
00293 
00294           calc_matrix = true;
00295 
00296           threshold = matrix_param->getAttributeAsDouble( "Threshold", CONF_MATRIX_DEFAULT_THRESHOLD );
00297 
00298           int ignore_absences_int = matrix_param->getAttributeAsInt( "IgnoreAbsences", 0 );
00299 
00300           if ( ignore_absences_int > 0 ) {
00301 
00302               ignore_abs = true;
00303           }
00304         }
00305         catch( SubsectionNotFound& e ) {
00306 
00307           UNUSED(e);
00308         }
00309 
00310         ConfigurationPtr roc_param;
00311 
00312         try {
00313 
00314           ConfigurationPtr roc_param = statistics_param->getSubsection( "RocCurve" );
00315 
00316           calc_roc = true;
00317 
00318           resolution = roc_param->getAttributeAsInt( "Resolution", -1 );
00319 
00320           num_background = roc_param->getAttributeAsInt( "BackgroundPoints", -1 );
00321 
00322           max_omission = roc_param->getAttributeAsDouble( "MaxOmission", 1.0 );
00323 
00324           int use_absences_as_background_int = roc_param->getAttributeAsInt( "UseAbsencesAsBackground", 0 );
00325 
00326           if ( use_absences_as_background_int > 0 ) {
00327 
00328             abs_background = true;
00329           }
00330         }
00331         catch( SubsectionNotFound& e ) {
00332 
00333           UNUSED(e);
00334         }
00335       }
00336       catch( SubsectionNotFound& e ) {
00337 
00338         // For backwards compatibility, calculate matrix and ROC if 
00339         // <Statistics> is not present
00340         calc_matrix = true;
00341         calc_roc = true;
00342         UNUSED(e);
00343       }
00344     }
00345     else {
00346 
00347       // Loading input from serialized model + TAB-delimited points file
00348 
00349       ConfigurationPtr input = Configuration::readXml( model_file.c_str() );
00350 
00351       alg = AlgorithmFactory::newAlgorithm( input->getSubsection( "Algorithm" ) );
00352 
00353       Log::instance()->debug( "Loading training sampler to get layers, label and spatial reference\n" );
00354 
00355       SamplerPtr training_sampler = createSampler( input->getSubsection( "Sampler" ) );
00356 
00357       // Get environment from training sampler
00358       EnvironmentPtr env = training_sampler->getEnvironment();
00359 
00360       // Get label and spatial reference from presence points of the training sampler
00361       OccurrencesPtr training_presences = training_sampler->getPresences();
00362 
00363       std::string label( training_presences->label() );
00364       std::string spatial_ref( training_presences->coordSystem() );
00365 
00366       Log::instance()->debug( "Loading test points %s %s\n", label.c_str(), spatial_ref.c_str() );
00367 
00368       OccurrencesReader* oc_reader = OccurrencesFactory::instance().create( points_file.c_str(), spatial_ref.c_str() );
00369 
00370       OccurrencesPtr presences = oc_reader->getPresences( label.c_str() );
00371       OccurrencesPtr absences = oc_reader->getAbsences( label.c_str() );
00372 
00373       delete oc_reader;
00374 
00375       // Create new sampler for test points
00376       sampler = createSampler( env, presences, absences );
00377     }
00378 
00379     if ( ! alg->done() ) {
00380 
00381       Log::instance()->error( "No model could be found as part of the specified algorithm. Aborting.\n");
00382 
00383       // If user is tracking progress
00384       if ( ! progress_file.empty() ) { 
00385 
00386         // -2 means aborted
00387         progressFileCallback( -2.0, &prog_data );
00388       }
00389 
00390       exit(-1);
00391     }
00392 
00393     // Run tests
00394 
00395     Log::instance()->debug( "Starting tests\n" );
00396 
00397     int num_presences = sampler->numPresence();
00398     int num_absences = sampler->numAbsence();
00399 
00400     ConfusionMatrix matrix;
00401 
00402     // Confusion matrix can only be calculated with presence and/or absence points
00403     if ( calc_matrix && ( num_presences || num_absences ) ) {
00404 
00405       if ( threshold < 0.0 ) {
00406 
00407         matrix.setLowestTrainingThreshold( alg->getModel(), sampler );
00408 
00409         threshold = matrix.getThreshold();
00410       }
00411 
00412       matrix.reset( threshold, ignore_abs );
00413 
00414       matrix.calculate( alg->getModel(), sampler );
00415     }
00416 
00417     RocCurve roc_curve;
00418 
00419     // ROC curve can only be calculated with presence points
00420     // No absence points will force background points to be generated
00421     if ( calc_roc && num_presences ) {
00422 
00423       resolution = (resolution <= 0) ? ROC_DEFAULT_RESOLUTION : resolution;
00424 
00425       if ( abs_background ) {
00426 
00427         roc_curve.initialize( resolution, true );
00428       }
00429       else {
00430 
00431         if ( num_background > 0 ) {
00432 
00433           roc_curve.initialize( resolution, num_background );
00434         }   
00435         else {
00436 
00437           roc_curve.initialize( resolution );
00438         }
00439       }
00440 
00441       roc_curve.calculate( alg->getModel(), sampler );
00442     }
00443 
00444     if ( calc_matrix && ! num_presences ) {
00445 
00446       Log::instance()->warn( "No presence points - ROC curve and omission error won't be calculated\n" );
00447     }
00448 
00449     if ( calc_matrix && ! num_absences ) {
00450 
00451       Log::instance()->warn( "No absence points - commission error won't be calculated\n" );
00452     }
00453 
00454 
00455     if ( calc_roc && ! num_presences ) {
00456 
00457       Log::instance()->warn( "No presence points - ROC curve won't be calculated\n" );
00458     }
00459 
00460     if ( calc_matrix ) {
00461 
00462       if ( num_presences || num_absences ) {
00463 
00464         Log::instance()->info( "\nModel statistics\n" );
00465         Log::instance()->info( "Accuracy:          %7.2f%%\n", matrix.getAccuracy() * 100 );
00466       }
00467 
00468       if ( num_presences ) {
00469 
00470         int omissions = matrix.getValue(0.0, 1.0);
00471         int total     = omissions + matrix.getValue(1.0, 1.0);
00472 
00473         Log::instance()->info( "Omission error:    %7.2f%% (%d/%d)\n", matrix.getOmissionError() * 100, omissions, total );
00474       }
00475 
00476       if ( num_absences ) {
00477 
00478         int commissions = matrix.getValue(1.0, 0.0);
00479         int total       = commissions + matrix.getValue(0.0, 0.0);
00480 
00481         Log::instance()->info( "Commission error:  %7.2f%% (%d/%d)\n", matrix.getCommissionError() * 100, commissions, total );
00482       }
00483     }
00484 
00485     if ( calc_roc ) {
00486 
00487       if ( num_presences ) {
00488 
00489         Log::instance()->info( "AUC:               %7.2f\n", roc_curve.getTotalArea() );
00490 
00491         if ( max_omission < 1.0 ) {
00492 
00493           Log::instance()->info( "Ratio:             %7.2f\n", roc_curve.getPartialAreaRatio( max_omission ) );
00494         }
00495       }
00496     }
00497 
00498     ConfigurationPtr output( new ConfigurationImpl("Statistics") );
00499 
00500     bool no_statistics = true;
00501 
00502     if ( calc_matrix && matrix.ready() ) {
00503 
00504       ConfigurationPtr cm_config( matrix.getConfiguration() );
00505 
00506       output->addSubsection( cm_config );
00507 
00508       no_statistics = false;
00509     }
00510 
00511     if ( calc_roc && roc_curve.ready() ) {
00512 
00513       ConfigurationPtr roc_config( roc_curve.getConfiguration() );
00514 
00515       output->addSubsection( roc_config );
00516 
00517       no_statistics = false;
00518     }
00519 
00520     if ( no_statistics )
00521     {
00522       Log::instance()->warn( "No statistics calculated\n" );
00523     }
00524 
00525     std::ostringstream test_output;
00526 
00527     Configuration::writeXml( output, test_output );
00528 
00529     std::cerr << flush;
00530 
00531     // Write test output to file, if requested
00532     if ( ! result_file.empty() ) {
00533 
00534       ofstream file( result_file.c_str() );
00535       file << test_output.str();
00536       file.close();
00537     }
00538     else {
00539 
00540       // Otherwise send it to stdout
00541       std::cout << test_output.str().c_str() << endl << flush;
00542     }
00543 
00544     // If user wants to track progress
00545     if ( ! progress_file.empty() ) { 
00546 
00547       // Indicate that the job is finished
00548       progressFileCallback( 1.0, &prog_data );
00549     }
00550   }
00551   catch ( runtime_error e ) {
00552 
00553     // If user is tracking progress
00554     if ( ! progress_file.empty() ) { 
00555 
00556       // -2 means aborted
00557       progressFileCallback( -2.0, &prog_data );
00558     }
00559 
00560     printf( "om_test aborted: %s\n", e.what() );
00561   }
00562 }