openModeller
Version 1.4.0
|
00001 #include <openmodeller/om.hh> 00002 #include <openmodeller/Exceptions.hh> 00003 #include <openmodeller/os_specific.hh> 00004 00005 #include "getopts/getopts.h" 00006 00007 #include "om_cmd_utils.hh" 00008 00009 #include <fstream> // file I/O for XML 00010 #include <sstream> // ostringstream datatype 00011 #include <stdio.h> // file I/O for log 00012 #include <time.h> // used to limit the number of times that the progress is written to a file 00013 #include <string> // string library 00014 #include <stdexcept> // try/catch 00015 00016 using namespace std; 00017 00019 int main( int argc, char **argv ) { 00020 00021 Options opts; 00022 int option; 00023 00024 // command-line parameters (short name, long name, description, take args) 00025 opts.addOption( "v", "version" , "Display version info" , false ); 00026 opts.addOption( "r", "xml-req" , "(option 1) Test request file in XML" , true ); 00027 opts.addOption( "o", "model" , "(option 2) Serialized model file" , true ); 00028 opts.addOption( "p", "points" , "(option 2) TAB-delimited file with points" , true ); 00029 opts.addOption( "" , "calc-matrix" , "Calculate confusion matrix" , false ); 00030 opts.addOption( "t", "threshold" , "Confusion matrix threshold" , true ); 00031 opts.addOption( "" , "ignore-abs" , "Ignore absences for the confusion matrix" , false ); 00032 opts.addOption( "" , "calc-roc" , "Calculate ROC curve" , false ); 00033 opts.addOption( "n", "resolution" , "Number of points in the ROC curve" , true ); 00034 opts.addOption( "b", "num-background", "Number of background points for the ROC curve when there are no absences", true ); 00035 opts.addOption( "e", "max-omission" , "Calculate ROC partial area ratio given the maximum omission", true ); 00036 opts.addOption( "" , "abs-background", "Use absences as background in ROC curve" , false ); 00037 opts.addOption( "s", "result" , "File to store test result in XML" , true ); 00038 opts.addOption( "", "log-level" , "Set the log level (debug, warn, info, error)", true ); 00039 opts.addOption( "", "log-file" , "Log file" , true ); 00040 opts.addOption( "" , "prog-file" , "File to store test progress" , true ); 00041 opts.addOption( "c", "config-file" , "Configuration file for openModeller" , true ); 00042 00043 std::string log_level("info"); 00044 std::string request_file; 00045 std::string model_file; 00046 std::string points_file; 00047 bool calc_matrix = false; 00048 std::string threshold_string(""); 00049 double threshold = CONF_MATRIX_DEFAULT_THRESHOLD; 00050 bool ignore_abs = false; 00051 bool calc_roc = false; 00052 std::string resolution_string(""); 00053 int resolution = -1; 00054 std::string num_background_string(""); 00055 int num_background = -1; 00056 std::string max_omission_string(""); 00057 double max_omission = 1.0; 00058 bool abs_background = false; 00059 std::string result_file; 00060 std::string log_file; 00061 std::string progress_file; 00062 std::string config_file; 00063 00064 if ( ! opts.parse( argc, argv ) ) { 00065 00066 opts.showHelp( argv[0] ); 00067 } 00068 00069 // Set up any related external resources 00070 setupExternalResources(); 00071 00072 OpenModeller om; 00073 00074 while ( ( option = opts.cycle() ) >= 0 ) { 00075 00076 switch ( option ) { 00077 00078 case 0: 00079 printf( "om_test %s\n", om.getVersion().c_str() ); 00080 printf("This is free software; see the source for copying conditions. There is NO\n"); 00081 printf("warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n"); 00082 exit(0); 00083 break; 00084 case 1: 00085 request_file = opts.getArgs( option ); 00086 break; 00087 case 2: 00088 model_file = opts.getArgs( option ); 00089 break; 00090 case 3: 00091 points_file = opts.getArgs( option ); 00092 break; 00093 case 4: 00094 calc_matrix = true; 00095 break; 00096 case 5: 00097 threshold_string = opts.getArgs( option ); 00098 break; 00099 case 6: 00100 ignore_abs = true; 00101 break; 00102 case 7: 00103 calc_roc = true; 00104 break; 00105 case 8: 00106 resolution_string = opts.getArgs( option ); 00107 break; 00108 case 9: 00109 num_background_string = opts.getArgs( option ); 00110 break; 00111 case 10: 00112 max_omission_string = opts.getArgs( option ); 00113 break; 00114 case 11: 00115 abs_background = true; 00116 break; 00117 case 12: 00118 result_file = opts.getArgs( option ); 00119 break; 00120 case 13: 00121 log_level = opts.getArgs( option ); 00122 break; 00123 case 14: 00124 log_file = opts.getArgs( option ); 00125 break; 00126 case 15: 00127 progress_file = opts.getArgs( option ); 00128 break; 00129 case 16: 00130 config_file = opts.getArgs( option ); 00131 break; 00132 default: 00133 break; 00134 } 00135 } 00136 00137 // om configuration 00138 if ( ! config_file.empty() ) { 00139 00140 Settings::loadConfig( config_file ); 00141 } 00142 00143 // Initialize progress data if user wants to track progress 00144 progress_data prog_data; 00145 00146 if ( ! progress_file.empty() ) { 00147 00148 prog_data.file_name = progress_file; 00149 00150 time( &prog_data.timestamp ); 00151 00152 prog_data.progress = -1.0; // queued 00153 00154 // Always create initial file with progress 0 00155 progressFileCallback( 0.0, &prog_data ); 00156 } 00157 00158 // Log stuff 00159 00160 Log::Level level_code = getLogLevel( log_level ); 00161 00162 if ( ! log_file.empty() ) { 00163 00164 Log::instance()->set( level_code, log_file, "" ); 00165 } 00166 else { 00167 00168 // Just set the level - things will go to stderr 00169 Log::instance()->setLevel( level_code ); 00170 } 00171 00172 // Check parameters 00173 00174 if ( ! request_file.empty() ) { 00175 00176 if ( ! model_file.empty() ) { 00177 00178 Log::instance()->warn( "Model file parameter will be ignored (using XML request instead)\n" ); 00179 } 00180 if ( ! points_file.empty() ) { 00181 00182 printf( "Points file parameter will be ignored (using XML request instead)\n"); 00183 } 00184 if ( calc_roc ) { 00185 00186 printf( "Parameter to calculate ROC curve will be ignored (when using XML request you should specify it in the XML)\n"); 00187 } 00188 if ( calc_matrix ) { 00189 00190 printf( "Parameter to calculate confusion matrix will be ignored (when using XML request you should specify it in the XML)\n"); 00191 } 00192 } 00193 else if ( ( ! model_file.empty() ) && ! points_file.empty() ) { 00194 00195 // Custom threshold 00196 if ( ! threshold_string.empty() ) { 00197 00198 threshold = atof( threshold_string.c_str() ); 00199 } 00200 00201 // Custom resolution 00202 if ( ! resolution_string.empty() ) { 00203 00204 resolution = atoi( resolution_string.c_str() ); 00205 } 00206 00207 // Custom number of background points 00208 if ( ! num_background_string.empty() ) { 00209 00210 num_background = atoi( num_background_string.c_str() ); 00211 } 00212 00213 // Custom max omission 00214 if ( ! max_omission_string.empty() ) { 00215 00216 max_omission = atof( max_omission_string.c_str() ); 00217 } 00218 } 00219 else { 00220 00221 printf( "Please specify either a test request file in XML or a serialized model and a TAB-delimited file with the points to be tested\n"); 00222 00223 // If user is tracking progress 00224 if ( ! progress_file.empty() ) { 00225 00226 // -2 means aborted 00227 progressFileCallback( -2.0, &prog_data ); 00228 } 00229 00230 exit(-1); 00231 } 00232 00233 if ( ! calc_matrix ) { 00234 00235 if ( ! threshold_string.empty() ) { 00236 00237 Log::instance()->warn( "Ignoring threshold - option only available with confusion matrix\n" ); 00238 } 00239 if ( ignore_abs ) { 00240 00241 Log::instance()->warn( "Ignoring ignore-abs - option only available with confusion matrix\n" ); 00242 } 00243 } 00244 00245 if ( ! calc_roc ) { 00246 00247 if ( ! resolution_string.empty() ) { 00248 00249 Log::instance()->warn( "Ignoring resolution - option only available with ROC curve\n" ); 00250 } 00251 if ( ! max_omission_string.empty() ) { 00252 00253 Log::instance()->warn( "Ignoring maximum omission - option only available with ROC curve\n" ); 00254 } 00255 if ( abs_background ) { 00256 00257 Log::instance()->warn( "Ignoring abs-background - option only available with ROC curve\n" ); 00258 } 00259 if ( ! num_background_string.empty() ) { 00260 00261 Log::instance()->warn( "Ignoring number of background points - option only available with ROC curve\n" ); 00262 } 00263 } 00264 00265 // Real work 00266 00267 try { 00268 00269 // Load algorithms and instantiate controller class 00270 AlgorithmFactory::searchDefaultDirs(); 00271 00272 SamplerPtr sampler; 00273 00274 AlgorithmPtr alg; 00275 00276 if ( ! request_file.empty() ) { 00277 00278 // Loading input from XML request 00279 00280 ConfigurationPtr input = Configuration::readXml( request_file.c_str() ); 00281 00282 alg = AlgorithmFactory::newAlgorithm( input->getSubsection( "Algorithm" ) ); 00283 00284 sampler = createSampler( input->getSubsection( "Sampler" ) ); 00285 00286 try { 00287 00288 ConfigurationPtr statistics_param = input->getSubsection( "Statistics" ); 00289 00290 try { 00291 00292 ConfigurationPtr matrix_param = statistics_param->getSubsection( "ConfusionMatrix" ); 00293 00294 calc_matrix = true; 00295 00296 threshold = matrix_param->getAttributeAsDouble( "Threshold", CONF_MATRIX_DEFAULT_THRESHOLD ); 00297 00298 int ignore_absences_int = matrix_param->getAttributeAsInt( "IgnoreAbsences", 0 ); 00299 00300 if ( ignore_absences_int > 0 ) { 00301 00302 ignore_abs = true; 00303 } 00304 } 00305 catch( SubsectionNotFound& e ) { 00306 00307 UNUSED(e); 00308 } 00309 00310 ConfigurationPtr roc_param; 00311 00312 try { 00313 00314 ConfigurationPtr roc_param = statistics_param->getSubsection( "RocCurve" ); 00315 00316 calc_roc = true; 00317 00318 resolution = roc_param->getAttributeAsInt( "Resolution", -1 ); 00319 00320 num_background = roc_param->getAttributeAsInt( "BackgroundPoints", -1 ); 00321 00322 max_omission = roc_param->getAttributeAsDouble( "MaxOmission", 1.0 ); 00323 00324 int use_absences_as_background_int = roc_param->getAttributeAsInt( "UseAbsencesAsBackground", 0 ); 00325 00326 if ( use_absences_as_background_int > 0 ) { 00327 00328 abs_background = true; 00329 } 00330 } 00331 catch( SubsectionNotFound& e ) { 00332 00333 UNUSED(e); 00334 } 00335 } 00336 catch( SubsectionNotFound& e ) { 00337 00338 // For backwards compatibility, calculate matrix and ROC if 00339 // <Statistics> is not present 00340 calc_matrix = true; 00341 calc_roc = true; 00342 UNUSED(e); 00343 } 00344 } 00345 else { 00346 00347 // Loading input from serialized model + TAB-delimited points file 00348 00349 ConfigurationPtr input = Configuration::readXml( model_file.c_str() ); 00350 00351 alg = AlgorithmFactory::newAlgorithm( input->getSubsection( "Algorithm" ) ); 00352 00353 Log::instance()->debug( "Loading training sampler to get layers, label and spatial reference\n" ); 00354 00355 SamplerPtr training_sampler = createSampler( input->getSubsection( "Sampler" ) ); 00356 00357 // Get environment from training sampler 00358 EnvironmentPtr env = training_sampler->getEnvironment(); 00359 00360 // Get label and spatial reference from presence points of the training sampler 00361 OccurrencesPtr training_presences = training_sampler->getPresences(); 00362 00363 std::string label( training_presences->label() ); 00364 std::string spatial_ref( training_presences->coordSystem() ); 00365 00366 Log::instance()->debug( "Loading test points %s %s\n", label.c_str(), spatial_ref.c_str() ); 00367 00368 OccurrencesReader* oc_reader = OccurrencesFactory::instance().create( points_file.c_str(), spatial_ref.c_str() ); 00369 00370 OccurrencesPtr presences = oc_reader->getPresences( label.c_str() ); 00371 OccurrencesPtr absences = oc_reader->getAbsences( label.c_str() ); 00372 00373 delete oc_reader; 00374 00375 // Create new sampler for test points 00376 sampler = createSampler( env, presences, absences ); 00377 } 00378 00379 if ( ! alg->done() ) { 00380 00381 Log::instance()->error( "No model could be found as part of the specified algorithm. Aborting.\n"); 00382 00383 // If user is tracking progress 00384 if ( ! progress_file.empty() ) { 00385 00386 // -2 means aborted 00387 progressFileCallback( -2.0, &prog_data ); 00388 } 00389 00390 exit(-1); 00391 } 00392 00393 // Run tests 00394 00395 Log::instance()->debug( "Starting tests\n" ); 00396 00397 int num_presences = sampler->numPresence(); 00398 int num_absences = sampler->numAbsence(); 00399 00400 ConfusionMatrix matrix; 00401 00402 // Confusion matrix can only be calculated with presence and/or absence points 00403 if ( calc_matrix && ( num_presences || num_absences ) ) { 00404 00405 if ( threshold < 0.0 ) { 00406 00407 matrix.setLowestTrainingThreshold( alg->getModel(), sampler ); 00408 00409 threshold = matrix.getThreshold(); 00410 } 00411 00412 matrix.reset( threshold, ignore_abs ); 00413 00414 matrix.calculate( alg->getModel(), sampler ); 00415 } 00416 00417 RocCurve roc_curve; 00418 00419 // ROC curve can only be calculated with presence points 00420 // No absence points will force background points to be generated 00421 if ( calc_roc && num_presences ) { 00422 00423 resolution = (resolution <= 0) ? ROC_DEFAULT_RESOLUTION : resolution; 00424 00425 if ( abs_background ) { 00426 00427 roc_curve.initialize( resolution, true ); 00428 } 00429 else { 00430 00431 if ( num_background > 0 ) { 00432 00433 roc_curve.initialize( resolution, num_background ); 00434 } 00435 else { 00436 00437 roc_curve.initialize( resolution ); 00438 } 00439 } 00440 00441 roc_curve.calculate( alg->getModel(), sampler ); 00442 } 00443 00444 if ( calc_matrix && ! num_presences ) { 00445 00446 Log::instance()->warn( "No presence points - ROC curve and omission error won't be calculated\n" ); 00447 } 00448 00449 if ( calc_matrix && ! num_absences ) { 00450 00451 Log::instance()->warn( "No absence points - commission error won't be calculated\n" ); 00452 } 00453 00454 00455 if ( calc_roc && ! num_presences ) { 00456 00457 Log::instance()->warn( "No presence points - ROC curve won't be calculated\n" ); 00458 } 00459 00460 if ( calc_matrix ) { 00461 00462 if ( num_presences || num_absences ) { 00463 00464 Log::instance()->info( "\nModel statistics\n" ); 00465 Log::instance()->info( "Accuracy: %7.2f%%\n", matrix.getAccuracy() * 100 ); 00466 } 00467 00468 if ( num_presences ) { 00469 00470 int omissions = matrix.getValue(0.0, 1.0); 00471 int total = omissions + matrix.getValue(1.0, 1.0); 00472 00473 Log::instance()->info( "Omission error: %7.2f%% (%d/%d)\n", matrix.getOmissionError() * 100, omissions, total ); 00474 } 00475 00476 if ( num_absences ) { 00477 00478 int commissions = matrix.getValue(1.0, 0.0); 00479 int total = commissions + matrix.getValue(0.0, 0.0); 00480 00481 Log::instance()->info( "Commission error: %7.2f%% (%d/%d)\n", matrix.getCommissionError() * 100, commissions, total ); 00482 } 00483 } 00484 00485 if ( calc_roc ) { 00486 00487 if ( num_presences ) { 00488 00489 Log::instance()->info( "AUC: %7.2f\n", roc_curve.getTotalArea() ); 00490 00491 if ( max_omission < 1.0 ) { 00492 00493 Log::instance()->info( "Ratio: %7.2f\n", roc_curve.getPartialAreaRatio( max_omission ) ); 00494 } 00495 } 00496 } 00497 00498 ConfigurationPtr output( new ConfigurationImpl("Statistics") ); 00499 00500 bool no_statistics = true; 00501 00502 if ( calc_matrix && matrix.ready() ) { 00503 00504 ConfigurationPtr cm_config( matrix.getConfiguration() ); 00505 00506 output->addSubsection( cm_config ); 00507 00508 no_statistics = false; 00509 } 00510 00511 if ( calc_roc && roc_curve.ready() ) { 00512 00513 ConfigurationPtr roc_config( roc_curve.getConfiguration() ); 00514 00515 output->addSubsection( roc_config ); 00516 00517 no_statistics = false; 00518 } 00519 00520 if ( no_statistics ) 00521 { 00522 Log::instance()->warn( "No statistics calculated\n" ); 00523 } 00524 00525 std::ostringstream test_output; 00526 00527 Configuration::writeXml( output, test_output ); 00528 00529 std::cerr << flush; 00530 00531 // Write test output to file, if requested 00532 if ( ! result_file.empty() ) { 00533 00534 ofstream file( result_file.c_str() ); 00535 file << test_output.str(); 00536 file.close(); 00537 } 00538 else { 00539 00540 // Otherwise send it to stdout 00541 std::cout << test_output.str().c_str() << endl << flush; 00542 } 00543 00544 // If user wants to track progress 00545 if ( ! progress_file.empty() ) { 00546 00547 // Indicate that the job is finished 00548 progressFileCallback( 1.0, &prog_data ); 00549 } 00550 } 00551 catch ( runtime_error e ) { 00552 00553 // If user is tracking progress 00554 if ( ! progress_file.empty() ) { 00555 00556 // -2 means aborted 00557 progressFileCallback( -2.0, &prog_data ); 00558 } 00559 00560 printf( "om_test aborted: %s\n", e.what() ); 00561 } 00562 }