// openModeller
// Version 1.4.0
00001 00027 #include <openmodeller/om.hh> 00028 #include <openmodeller/Configuration.hh> 00029 #include <openmodeller/os_specific.hh> 00030 00031 #include "request_file.hh" 00032 #include "om_cmd_utils.hh" 00033 00034 #include <istream> 00035 #include <stdlib.h> 00036 #include <string.h> 00037 #include <stdio.h> 00038 #include <string> 00039 #include <stdexcept> 00040 00041 #ifdef MPI_FOUND 00042 #include "mpi.h" 00043 #endif 00044 00045 int showAlgorithms ( AlgMetadata const **availables ); 00046 AlgMetadata const *readAlgorithm( AlgMetadata const **availables ); 00047 int readParameters( AlgParameter *result, AlgMetadata const *metadata ); 00048 char *extractParameter( char *name, int nvet, char **vet ); 00049 00050 void mapCallback( float progress, void *extra_param ); 00051 void modelCallback( float progress, void * extra_param ); 00052 00053 /**************************************************************/ 00054 /*************** openModeller Console Interface ***************/ 00055 00056 /************/ 00057 /*** main ***/ 00058 int 00059 main( int argc, char **argv ) 00060 { 00061 if ( argc < 2 ) { 00062 00063 printf ("Usage: %s request_file [log_level [config_file]]\n", argv[0]); 00064 exit(1); 00065 } 00066 00067 // Reconfigure the global logger 00068 string log_level("info"); 00069 00070 if ( argc > 2 ) { 00071 00072 log_level.assign( argv[2] ); 00073 } 00074 if ( argc > 3 ) { 00075 00076 Settings::loadConfig( argv[3] ); 00077 } 00078 00079 Log::Level level_code = getLogLevel( log_level ); 00080 Log::instance()->setLevel( level_code ); 00081 Log::instance()->setPrefix( "" ); 00082 00083 // Set up any related external resources 00084 setupExternalResources(); 00085 00086 try { 00087 00088 char *request_file = argv[1]; 00089 00090 AlgorithmFactory::searchDefaultDirs(); 00091 OpenModeller om; 00092 Log::instance()->info( "openModeller version %s\n", om.getVersion().c_str() ); 00093 00094 // Configure the OpenModeller object from data read from the 00095 // 
request file. 00096 RequestFile request; 00097 int resp = request.configure( &om, request_file ); 00098 00099 if ( resp < 0 ) { 00100 00101 Log::instance()->error( "Could not read request file %s", request_file ); 00102 exit(1); 00103 } 00104 00105 // If something was not set... 00106 if ( resp ) { 00107 00108 if ( ! request.occurrencesSet() ) { 00109 00110 exit(1); 00111 } 00112 00113 if ( ! request.algorithmSet() ) { 00114 00115 // Find out which model algorithm is to be used. 00116 AlgMetadata const **availables = om.availableAlgorithms(); 00117 AlgMetadata const *metadata; 00118 00119 if ( ! (metadata = readAlgorithm( availables )) ) { 00120 00121 return 1; 00122 } 00123 00124 Log::instance()->info( "\n> Algorithm used: %s\n\n", metadata->name.c_str() ); 00125 Log::instance()->info( " %s\n\n", metadata->overview.c_str() ); 00126 00127 // For resulting parameters storage. 00128 int nparam = metadata->nparam; 00129 AlgParameter *param = new AlgParameter[nparam]; 00130 00131 // Read from console the parameters not set by request 00132 // file. Fills 'param' with all 'metadata->nparam' 00133 // parameters set. 00134 readParameters( param, metadata ); 00135 00136 // Set the model algorithm to be used by the controller 00137 om.setAlgorithm( metadata->id, nparam, param ); 00138 00139 delete[] param; 00140 delete[] availables; 00141 } 00142 } 00143 00144 #ifdef MPI_FOUND 00145 00146 Log::instance()->info( "Running the parallel version of the algorithm\n" ); 00147 00148 MPI_Init( &argc, &argv ); // MPI initialization 00149 00150 int rank; 00151 00152 MPI_Comm_rank( MPI_COMM_WORLD, &rank ); 00153 00154 #endif 00155 00156 // Run model 00157 om.setModelCallback( modelCallback ); 00158 00159 request.makeModel( &om ); 00160 00161 if ( request.requestedProjection() ) { 00162 00163 // Run projection 00164 om.setMapCallback( mapCallback ); 00165 00166 try { 00167 00168 request.makeProjection( &om ); 00169 } 00170 catch ( ... 
) {} 00171 } 00172 else { 00173 00174 Log::instance()->warn( "Skipping projection\n" ); 00175 } 00176 00177 if ( request.calcConfusionMatrix() ) { 00178 00179 // Instantiate objects for model statistics 00180 const ConfusionMatrix * const matrix = om.getConfusionMatrix(); 00181 00182 // Confusion Matrix 00183 Log::instance()->info( "\n" ); 00184 Log::instance()->info( "Model statistics for training data\n" ); 00185 Log::instance()->info( "Threshold: %7.2f%%\n", matrix->getThreshold() * 100 ); 00186 Log::instance()->info( "Accuracy: %7.2f%%\n", matrix->getAccuracy() * 100 ); 00187 00188 int omissions = matrix->getValue(0.0, 1.0); 00189 int total = omissions + matrix->getValue(1.0, 1.0); 00190 00191 Log::instance()->info( "Omission error: %7.2f%% (%d/%d)\n", matrix->getOmissionError() * 100, omissions, total ); 00192 00193 double commissionError = matrix->getCommissionError(); 00194 00195 if ( commissionError >= 0.0 ) { 00196 00197 int commissions = matrix->getValue(1.0, 0.0); 00198 total = commissions + matrix->getValue(0.0, 0.0); 00199 00200 Log::instance()->info( "Commission error: %7.2f%% (%d/%d)\n", commissionError * 100, commissions, total ); 00201 } 00202 00203 ConfusionMatrix auxMatrix; 00204 auxMatrix.setLowestTrainingThreshold( om.getModel(), om.getSampler() ); 00205 Log::instance()->info( "Lowest prediction: %7.2f\n", auxMatrix.getThreshold() ); 00206 } 00207 00208 if ( request.calcAuc() ) { 00209 00210 RocCurve * const roc_curve = om.getRocCurve(); 00211 00212 // ROC curve 00213 Log::instance()->info( "AUC: %7.2f\n", roc_curve->getTotalArea() ); 00214 } 00215 00216 // Projection statistics 00217 if ( request.requestedProjection() ) { 00218 00219 Log::instance()->info( "\n" ); 00220 Log::instance()->info( "Projection statistics\n" ); 00221 00222 AreaStats * stats = om.getActualAreaStats(); 00223 00224 Log::instance()->info( "Threshold: 50%%\n" ); 00225 Log::instance()->info( "Cells predicted present: %7.2f%%\n", 00226 stats->getAreaPredictedPresent() / 
(double) stats->getTotalArea() * 100 ); 00227 Log::instance()->info( "Total number of cells: %d\n", stats->getTotalArea() ); 00228 Log::instance()->info( "Done.\n" ); 00229 00230 delete stats; 00231 } 00232 } 00233 catch ( std::exception& e ) { 00234 Log::instance()->info( "Exception occurred: %s", e.what() ); 00235 } 00236 catch ( ... ) { 00237 Log::instance()->info( "Unknown error occurred\n" ); 00238 } 00239 00240 #ifdef MPI_FOUND 00241 MPI_Finalize(); 00242 #endif 00243 00244 return 0; 00245 } 00246 00247 00248 /***********************/ 00249 /*** show algorithms ***/ 00250 // 00251 // Print available algorithms. 00252 // Returns the option number associated with 'Quit' that is 00253 // equal to the number of algorithms. 00254 // 00255 int 00256 showAlgorithms( AlgMetadata const **availables ) 00257 { 00258 if ( ! *availables ) 00259 { 00260 printf( "Could not find any algorithms.\n" ); 00261 return 0; 00262 } 00263 00264 printf( "\nChoose an algorithm between:\n" ); 00265 00266 int count = 1; 00267 AlgMetadata const *metadata; 00268 while ( ( metadata = *availables++ ) ) 00269 { 00270 printf( " [%d] %s\n", count++, metadata->name.c_str() ); 00271 } 00272 printf( " [q] Quit\n" ); 00273 printf( "\n" ); 00274 00275 return count; 00276 } 00277 00278 00279 /**********************/ 00280 /*** read algorithm ***/ 00281 // 00282 // Let the user choose an algorithm and enter its parameters. 00283 // Returns the choosed algorithm's metadata. 00284 // 00285 AlgMetadata const * 00286 readAlgorithm( AlgMetadata const **availables ) 00287 { 00288 char buf[128]; 00289 00290 while ( 1 ) { 00291 00292 int quit_option = showAlgorithms( availables ); 00293 if ( ! 
quit_option ) { 00294 return 0; 00295 } 00296 00297 int option = -1; 00298 00299 printf( "\nOption: " ); 00300 fgets( buf, 128, stdin ); 00301 00302 int first_char_ascii = (int)buf[0]; 00303 00304 // Quit if input is "q" or "Q" 00305 if ( first_char_ascii == 113 || first_char_ascii == 81 ) { 00306 return 0; 00307 } 00308 00309 option = atoi( buf ); 00310 00311 if ( option <= 0 || option >= quit_option ) { 00312 return 0; 00313 } 00314 00315 // An algorithm was choosed. 00316 else { 00317 return availables[option-1]; 00318 } 00319 } 00320 } 00321 00322 00323 /***********************/ 00324 /*** read Parameters ***/ 00325 int 00326 readParameters( AlgParameter *result, AlgMetadata const *metadata ) 00327 { 00328 AlgParamMetadata *param = metadata->param; 00329 AlgParamMetadata *end = param + metadata->nparam; 00330 00331 // Read from stdin each algorithm parameter. 00332 for ( ; param < end; param++, result++ ) 00333 { 00334 // The resulting ID is equal the ID set in algorithm's 00335 // metadata. 00336 result->setId( param->id ); 00337 00338 // Informs the parameter's metadata to the user. 00339 printf( "\n* Parameter: %s\n\n", param->name.c_str() ); 00340 printf( " %s\n", param->overview.c_str() ); 00341 00342 if ( param->type != String ) { 00343 00344 if ( param->has_min ) { 00345 00346 if ( param->type == Integer ) { 00347 00348 printf( "%s >= %d\n", param->name.c_str(), int( param->min_val ) ); 00349 } 00350 else { 00351 00352 printf( " %s >= %f\n", param->name.c_str(), param->min_val ); 00353 } 00354 } 00355 if ( param->has_max ) { 00356 00357 if ( param->type == Integer ) { 00358 00359 printf( "%s <= %d\n\n", param->name.c_str(), int( param->max_val ) ); 00360 } 00361 else { 00362 00363 printf( " %s <= %f\n\n", param->name.c_str(), param->max_val ); 00364 } 00365 } 00366 } 00367 00368 printf( "Enter with value [%s]: ", param->typical.c_str() ); 00369 00370 // Read parameter's value or use the "typical" value 00371 // if the user does not enter a new value. 
00372 char value[64]; 00373 *value = 0; 00374 if ( fgets( value, 64, stdin ) && ( *value >= ' ' ) ) { 00375 00376 // Remove line feed to avoid problems with string parameters 00377 if ( param->type == String ) { 00378 00379 char * pos = strchr( value, '\n' ); 00380 00381 if ( pos ) { 00382 00383 *pos = '\0'; 00384 } 00385 } 00386 00387 result->setValue( value ); 00388 } 00389 else { 00390 00391 result->setValue( param->typical ); 00392 } 00393 } 00394 00395 return metadata->nparam; 00396 } 00397 00398 00399 /*************************/ 00400 /*** extract Parameter ***/ 00407 char * 00408 extractParameter( char *id, int nvet, char **vet ) 00409 { 00410 int length = strlen( id ); 00411 char **end = vet + nvet; 00412 00413 while ( vet < end ) 00414 if ( ! strncmp( id, *vet++, length ) ) 00415 return *(vet-1) + length; 00416 00417 return 0; 00418 } 00419 00420 00421 /********************/ 00422 /*** map Callback ***/ 00426 void 00427 modelCallback( float progress, void *extra_param ) 00428 { 00429 Log::instance()->info( "Model creation: %07.4f%%\r", 100 * progress ); 00430 } 00431 00432 00433 /********************/ 00434 /*** map Callback ***/ 00438 void 00439 mapCallback( float progress, void *extra_param ) 00440 { 00441 Log::instance()->info( "Map creation: %07.4f%%\r", 100 * progress ); 00442 }