openModeller  Version 1.5.0
om_test.cpp
Go to the documentation of this file.
1 #include <openmodeller/om.hh>
4 
5 #include "getopts/getopts.h"
6 
7 #include "om_cmd_utils.hh"
8 
9 #include <fstream> // file I/O for XML
10 #include <sstream> // ostringstream datatype
11 #include <stdio.h> // file I/O for log
12 #include <time.h> // used to limit the number of times that the progress is written to a file
13 #include <string> // string library
14 #include <stdexcept> // try/catch
15 
16 using namespace std;
17 
19 int main( int argc, char **argv ) {
20 
21  Options opts;
22  int option;
23 
24  // command-line parameters (short name, long name, description, take args)
25  opts.addOption( "v", "version" , "Display version info" , false );
26  opts.addOption( "r", "xml-req" , "(option 1) Test request file in XML" , true );
27  opts.addOption( "o", "model" , "(option 2) Serialized model file" , true );
28  opts.addOption( "p", "points" , "(option 2) TAB-delimited file with points" , true );
29  opts.addOption( "" , "calc-matrix" , "Calculate confusion matrix" , false );
30  opts.addOption( "t", "threshold" , "Confusion matrix threshold" , true );
31  opts.addOption( "" , "ignore-abs" , "Ignore absences for the confusion matrix" , false );
32  opts.addOption( "" , "calc-roc" , "Calculate ROC curve" , false );
33  opts.addOption( "n", "resolution" , "Number of points in the ROC curve" , true );
34  opts.addOption( "b", "num-background", "Number of background points for the ROC curve when there are no absences", true );
35  opts.addOption( "e", "max-omission" , "Calculate ROC partial area ratio given the maximum omission", true );
36  opts.addOption( "" , "abs-background", "Use absences as background in ROC curve" , false );
37  opts.addOption( "s", "result" , "File to store test result in XML" , true );
38  opts.addOption( "", "log-level" , "Set the log level (debug, warn, info, error)", true );
39  opts.addOption( "", "log-file" , "Log file" , true );
40  opts.addOption( "" , "prog-file" , "File to store test progress" , true );
41  opts.addOption( "c", "config-file" , "Configuration file for openModeller" , true );
42 
43  std::string log_level("info");
44  std::string request_file;
45  std::string model_file;
46  std::string points_file;
47  bool calc_matrix = false;
48  std::string threshold_string("");
49  double threshold = CONF_MATRIX_DEFAULT_THRESHOLD;
50  bool ignore_abs = false;
51  bool calc_roc = false;
52  std::string resolution_string("");
53  int resolution = -1;
54  std::string num_background_string("");
55  int num_background = -1;
56  std::string max_omission_string("");
57  double max_omission = 1.0;
58  bool abs_background = false;
59  std::string result_file;
60  std::string log_file;
61  std::string progress_file;
62  std::string config_file;
63 
64  if ( ! opts.parse( argc, argv ) ) {
65 
66  opts.showHelp( argv[0] );
67  }
68 
69  // Set up any related external resources
71 
72  OpenModeller om;
73 
74  while ( ( option = opts.cycle() ) >= 0 ) {
75 
76  switch ( option ) {
77 
78  case 0:
79  printf( "om_test %s\n", om.getVersion().c_str() );
80  printf("This is free software; see the source for copying conditions. There is NO\n");
81  printf("warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n");
82  exit(0);
83  break;
84  case 1:
85  request_file = opts.getArgs( option );
86  break;
87  case 2:
88  model_file = opts.getArgs( option );
89  break;
90  case 3:
91  points_file = opts.getArgs( option );
92  break;
93  case 4:
94  calc_matrix = true;
95  break;
96  case 5:
97  threshold_string = opts.getArgs( option );
98  break;
99  case 6:
100  ignore_abs = true;
101  break;
102  case 7:
103  calc_roc = true;
104  break;
105  case 8:
106  resolution_string = opts.getArgs( option );
107  break;
108  case 9:
109  num_background_string = opts.getArgs( option );
110  break;
111  case 10:
112  max_omission_string = opts.getArgs( option );
113  break;
114  case 11:
115  abs_background = true;
116  break;
117  case 12:
118  result_file = opts.getArgs( option );
119  break;
120  case 13:
121  log_level = opts.getArgs( option );
122  break;
123  case 14:
124  log_file = opts.getArgs( option );
125  break;
126  case 15:
127  progress_file = opts.getArgs( option );
128  break;
129  case 16:
130  config_file = opts.getArgs( option );
131  break;
132  default:
133  break;
134  }
135  }
136 
137  // Initialize progress data if user wants to track progress
138  progress_data prog_data;
139 
140  if ( ! progress_file.empty() ) {
141 
142  prog_data.file_name = progress_file;
143 
144  time( &prog_data.timestamp );
145 
146  prog_data.progress = -1.0; // queued
147 
148  // Always create initial file with progress 0
149  progressFileCallback( 0.0, &prog_data );
150  }
151 
152  // Log stuff
153 
154  Log::Level level_code = getLogLevel( log_level );
155 
156  if ( ! log_file.empty() ) {
157 
158  Log::instance()->set( level_code, log_file, "" );
159  }
160  else {
161 
162  // Just set the level - things will go to stderr
163  Log::instance()->setLevel( level_code );
164  }
165 
166  // om configuration
167  if ( ! config_file.empty() ) {
168 
169  Settings::loadConfig( config_file );
170  }
171 
172  // Check parameters
173 
174  if ( ! request_file.empty() ) {
175 
176  if ( ! model_file.empty() ) {
177 
178  Log::instance()->warn( "Model file parameter will be ignored (using XML request instead)\n" );
179  }
180  if ( ! points_file.empty() ) {
181 
182  printf( "Points file parameter will be ignored (using XML request instead)\n");
183  }
184  if ( calc_roc ) {
185 
186  printf( "Parameter to calculate ROC curve will be ignored (when using XML request you should specify it in the XML)\n");
187  }
188  if ( calc_matrix ) {
189 
190  printf( "Parameter to calculate confusion matrix will be ignored (when using XML request you should specify it in the XML)\n");
191  }
192  }
193  else if ( ( ! model_file.empty() ) && ! points_file.empty() ) {
194 
195  // Custom threshold
196  if ( ! threshold_string.empty() ) {
197 
198  threshold = atof( threshold_string.c_str() );
199  }
200 
201  // Custom resolution
202  if ( ! resolution_string.empty() ) {
203 
204  resolution = atoi( resolution_string.c_str() );
205  }
206 
207  // Custom number of background points
208  if ( ! num_background_string.empty() ) {
209 
210  num_background = atoi( num_background_string.c_str() );
211  }
212 
213  // Custom max omission
214  if ( ! max_omission_string.empty() ) {
215 
216  max_omission = atof( max_omission_string.c_str() );
217  }
218  }
219  else {
220 
221  printf( "Please specify either a test request file in XML or a serialized model and a TAB-delimited file with the points to be tested\n");
222 
223  // If user is tracking progress
224  if ( ! progress_file.empty() ) {
225 
226  // -2 means aborted
227  progressFileCallback( -2.0, &prog_data );
228  }
229 
230  exit(-1);
231  }
232 
233  if ( ! calc_matrix ) {
234 
235  if ( ! threshold_string.empty() ) {
236 
237  Log::instance()->warn( "Ignoring threshold - option only available with confusion matrix\n" );
238  }
239  if ( ignore_abs ) {
240 
241  Log::instance()->warn( "Ignoring ignore-abs - option only available with confusion matrix\n" );
242  }
243  }
244 
245  if ( ! calc_roc ) {
246 
247  if ( ! resolution_string.empty() ) {
248 
249  Log::instance()->warn( "Ignoring resolution - option only available with ROC curve\n" );
250  }
251  if ( ! max_omission_string.empty() ) {
252 
253  Log::instance()->warn( "Ignoring maximum omission - option only available with ROC curve\n" );
254  }
255  if ( abs_background ) {
256 
257  Log::instance()->warn( "Ignoring abs-background - option only available with ROC curve\n" );
258  }
259  if ( ! num_background_string.empty() ) {
260 
261  Log::instance()->warn( "Ignoring number of background points - option only available with ROC curve\n" );
262  }
263  }
264 
265  // Real work
266 
267  try {
268 
269  // Load algorithms and instantiate controller class
271 
272  SamplerPtr sampler;
273 
274  AlgorithmPtr alg;
275 
276  if ( ! request_file.empty() ) {
277 
278  // Loading input from XML request
279 
280  ConfigurationPtr input = Configuration::readXml( request_file.c_str() );
281 
282  alg = AlgorithmFactory::newAlgorithm( input->getSubsection( "Algorithm" ) );
283 
284  sampler = createSampler( input->getSubsection( "Sampler" ) );
285 
286  try {
287 
288  ConfigurationPtr statistics_param = input->getSubsection( "Statistics" );
289 
290  try {
291 
292  ConfigurationPtr matrix_param = statistics_param->getSubsection( "ConfusionMatrix" );
293 
294  calc_matrix = true;
295 
296  threshold = matrix_param->getAttributeAsDouble( "Threshold", CONF_MATRIX_DEFAULT_THRESHOLD );
297 
298  int ignore_absences_int = matrix_param->getAttributeAsInt( "IgnoreAbsences", 0 );
299 
300  if ( ignore_absences_int > 0 ) {
301 
302  ignore_abs = true;
303  }
304  }
305  catch( SubsectionNotFound& e ) {
306 
307  UNUSED(e);
308  }
309 
310  ConfigurationPtr roc_param;
311 
312  try {
313 
314  ConfigurationPtr roc_param = statistics_param->getSubsection( "RocCurve" );
315 
316  calc_roc = true;
317 
318  resolution = roc_param->getAttributeAsInt( "Resolution", -1 );
319 
320  num_background = roc_param->getAttributeAsInt( "BackgroundPoints", -1 );
321 
322  max_omission = roc_param->getAttributeAsDouble( "MaxOmission", 1.0 );
323 
324  int use_absences_as_background_int = roc_param->getAttributeAsInt( "UseAbsencesAsBackground", 0 );
325 
326  if ( use_absences_as_background_int > 0 ) {
327 
328  abs_background = true;
329  }
330  }
331  catch( SubsectionNotFound& e ) {
332 
333  UNUSED(e);
334  }
335  }
336  catch( SubsectionNotFound& e ) {
337 
338  // For backwards compatibility, calculate matrix and ROC if
339  // <Statistics> is not present
340  calc_matrix = true;
341  calc_roc = true;
342  UNUSED(e);
343  }
344  }
345  else {
346 
347  // Loading input from serialized model + TAB-delimited points file
348 
349  ConfigurationPtr input = Configuration::readXml( model_file.c_str() );
350 
351  alg = AlgorithmFactory::newAlgorithm( input->getSubsection( "Algorithm" ) );
352 
353  Log::instance()->debug( "Loading training sampler to get layers, label and spatial reference\n" );
354 
355  SamplerPtr training_sampler = createSampler( input->getSubsection( "Sampler" ) );
356 
357  // Get environment from training sampler
358  EnvironmentPtr env = training_sampler->getEnvironment();
359 
360  // Get label and spatial reference from presence points of the training sampler
361  OccurrencesPtr training_presences = training_sampler->getPresences();
362 
363  std::string label( training_presences->label() );
364  std::string spatial_ref( training_presences->coordSystem() );
365 
366  Log::instance()->debug( "Loading test points %s %s\n", label.c_str(), spatial_ref.c_str() );
367 
368  OccurrencesReader* oc_reader = OccurrencesFactory::instance().create( points_file.c_str(), spatial_ref.c_str() );
369 
370  OccurrencesPtr presences = oc_reader->getPresences( label.c_str() );
371  OccurrencesPtr absences = oc_reader->getAbsences( label.c_str() );
372 
373  delete oc_reader;
374 
375  // Create new sampler for test points
376  sampler = createSampler( env, presences, absences );
377  }
378 
379  if ( ! alg->done() ) {
380 
381  Log::instance()->error( "No model could be found as part of the specified algorithm. Aborting.\n");
382 
383  // If user is tracking progress
384  if ( ! progress_file.empty() ) {
385 
386  // -2 means aborted
387  progressFileCallback( -2.0, &prog_data );
388  }
389 
390  exit(-1);
391  }
392 
393  // Run tests
394 
395  Log::instance()->debug( "Starting tests\n" );
396 
397  int num_presences = sampler->numPresence();
398  int num_absences = sampler->numAbsence();
399 
400  ConfusionMatrix matrix;
401 
402  // Confusion matrix can only be calculated with presence and/or absence points
403  if ( calc_matrix && ( num_presences || num_absences ) ) {
404 
405  if ( threshold < 0.0 ) {
406 
407  matrix.setLowestTrainingThreshold( alg->getModel(), sampler );
408 
409  threshold = matrix.getThreshold();
410  }
411 
412  matrix.reset( threshold, ignore_abs );
413 
414  matrix.calculate( alg->getModel(), sampler );
415  }
416 
417  RocCurve roc_curve;
418 
419  // ROC curve can only be calculated with presence points
420  // No absence points will force background points to be generated
421  if ( calc_roc && num_presences ) {
422 
423  resolution = (resolution <= 0) ? ROC_DEFAULT_RESOLUTION : resolution;
424 
425  if ( abs_background ) {
426 
427  roc_curve.initialize( resolution, true );
428  }
429  else {
430 
431  if ( num_background > 0 ) {
432 
433  roc_curve.initialize( resolution, num_background );
434  }
435  else {
436 
437  roc_curve.initialize( resolution );
438  }
439  }
440 
441  roc_curve.calculate( alg->getModel(), sampler );
442  }
443 
444  if ( calc_matrix && ! num_presences ) {
445 
446  Log::instance()->warn( "No presence points - ROC curve and omission error won't be calculated\n" );
447  }
448 
449  if ( calc_matrix && ! num_absences ) {
450 
451  Log::instance()->warn( "No absence points - commission error won't be calculated\n" );
452  }
453 
454 
455  if ( calc_roc && ! num_presences ) {
456 
457  Log::instance()->warn( "No presence points - ROC curve won't be calculated\n" );
458  }
459 
460  if ( calc_matrix ) {
461 
462  if ( num_presences || num_absences ) {
463 
464  Log::instance()->info( "\nModel statistics\n" );
465  Log::instance()->info( "Accuracy: %7.2f%%\n", matrix.getAccuracy() * 100 );
466  }
467 
468  if ( num_presences ) {
469 
470  int omissions = matrix.getValue(0.0, 1.0);
471  int total = omissions + matrix.getValue(1.0, 1.0);
472 
473  Log::instance()->info( "Omission error: %7.2f%% (%d/%d)\n", matrix.getOmissionError() * 100, omissions, total );
474  }
475 
476  if ( num_absences ) {
477 
478  int commissions = matrix.getValue(1.0, 0.0);
479  int total = commissions + matrix.getValue(0.0, 0.0);
480 
481  Log::instance()->info( "Commission error: %7.2f%% (%d/%d)\n", matrix.getCommissionError() * 100, commissions, total );
482  }
483  }
484 
485  if ( calc_roc ) {
486 
487  if ( num_presences ) {
488 
489  Log::instance()->info( "AUC: %7.2f\n", roc_curve.getTotalArea() );
490 
491  if ( max_omission < 1.0 ) {
492 
493  Log::instance()->info( "Ratio: %7.2f\n", roc_curve.getPartialAreaRatio( max_omission ) );
494  }
495  }
496  }
497 
498  ConfigurationPtr output( new ConfigurationImpl("Statistics") );
499 
500  bool no_statistics = true;
501 
502  if ( calc_matrix && matrix.ready() ) {
503 
504  ConfigurationPtr cm_config( matrix.getConfiguration() );
505 
506  output->addSubsection( cm_config );
507 
508  no_statistics = false;
509  }
510 
511  if ( calc_roc && roc_curve.ready() ) {
512 
513  ConfigurationPtr roc_config( roc_curve.getConfiguration() );
514 
515  output->addSubsection( roc_config );
516 
517  no_statistics = false;
518  }
519 
520  if ( no_statistics )
521  {
522  Log::instance()->warn( "No statistics calculated\n" );
523  }
524 
525  std::ostringstream test_output;
526 
527  Configuration::writeXml( output, test_output );
528 
529  std::cerr << flush;
530 
531  // Write test output to file, if requested
532  if ( ! result_file.empty() ) {
533 
534  ofstream file( result_file.c_str() );
535  file << test_output.str();
536  file.close();
537  }
538  else {
539 
540  // Otherwise send it to stdout
541  std::cout << test_output.str().c_str() << endl << flush;
542  }
543 
544  // If user wants to track progress
545  if ( ! progress_file.empty() ) {
546 
547  // Indicate that the job is finished
548  progressFileCallback( 1.0, &prog_data );
549  }
550  }
551  catch ( runtime_error e ) {
552 
553  // If user is tracking progress
554  if ( ! progress_file.empty() ) {
555 
556  // -2 means aborted
557  progressFileCallback( -2.0, &prog_data );
558  }
559 
560  printf( "om_test aborted: %s\n", e.what() );
561  }
562 }
static void loadConfig(const std::string configFile)
Definition: Settings.cpp:100
void reset(Scalar predictionThreshold=CONF_MATRIX_DEFAULT_THRESHOLD, bool ignoreAbsences=false)
void initialize(int resolution=ROC_DEFAULT_RESOLUTION)
Definition: RocCurve.cpp:70
void warn(const char *format,...)
'Warn' level.
Definition: Log.cpp:273
static ConfigurationPtr readXml(char const *filename)
static AlgorithmPtr newAlgorithm(std::string const id)
double getAccuracy() const
static Log * instance()
Returns the instance pointer, creating the object on the first call.
Definition: Log.cpp:45
std::string file_name
Definition: om_cmd_utils.hh:42
int getValue(Scalar predictionValue, Scalar actualValue) const
static OccurrencesFactory & instance()
Log::Level getLogLevel(std::string level)
virtual OccurrencesPtr getAbsences(const char *groupId)
Level
Definition: Log.hh:54
void setLevel(Level level)
Definition: Log.hh:107
double getCommissionError() const
SamplerPtr createSampler(const EnvironmentPtr &env, const OccurrencesPtr &presence, const OccurrencesPtr &absence)
Definition: Sampler.cpp:52
void error(const char *format,...)
'Error' level.
Definition: Log.cpp:290
void setupExternalResources()
Definition: os_specific.cpp:95
virtual OccurrencesPtr getPresences(const char *groupId)
void calculate(const EnvironmentPtr &env, const Model &model, const OccurrencesPtr &presences, const OccurrencesPtr &absences=OccurrencesPtr())
int main(int argc, char **argv)
Main code.
Definition: om_test.cpp:19
#define UNUSED(symbol)
Definition: os_specific.hh:55
bool ready() const
Definition: RocCurve.hh:144
void set(Level level, std::string fileName, char const *pref="")
Definition: Log.cpp:196
static void writeXml(const ConstConfigurationPtr &config, char const *fileaname)
static int searchDefaultDirs()
void setLowestTrainingThreshold(const Model &model, const SamplerPtr &sampler)
std::string getVersion()
double getOmissionError() const
bool ready() const
double getThreshold() const
#define ROC_DEFAULT_RESOLUTION
Definition: RocCurve.hh:40
void info(const char *format,...)
'Info' level.
Definition: Log.cpp:256
void calculate(const Model &model, const SamplerPtr &sampler)
Definition: RocCurve.cpp:131
ConfigurationPtr getConfiguration() const
Definition: RocCurve.cpp:739
ConfigurationPtr getConfiguration() const
#define CONF_MATRIX_DEFAULT_THRESHOLD
double getTotalArea()
Definition: RocCurve.cpp:602
void debug(const char *format,...)
'Debug' level.
Definition: Log.cpp:237
double getPartialAreaRatio(double e=1.0)
Definition: RocCurve.cpp:615
time_t timestamp
Definition: om_cmd_utils.hh:43
void progressFileCallback(float progress, void *progressData)
OccurrencesReader * create(const char *source, const char *coordSystem)