- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
src/Actions/PotentialAction/FitPotentialAction.cpp
r0ea063 r707a2b 55 55 #include "Fragmentation/Homology/HomologyGraph.hpp" 56 56 #include "Fragmentation/Summation/SetValues/Fragment.hpp" 57 #include "Potentials/EmpiricalPotential.hpp" 57 #include "FunctionApproximation/Extractors.hpp" 58 #include "FunctionApproximation/FunctionApproximation.hpp" 59 #include "FunctionApproximation/FunctionModel.hpp" 60 #include "FunctionApproximation/TrainingData.hpp" 61 #include "FunctionApproximation/writeDistanceEnergyTable.hpp" 62 #include "Potentials/CompoundPotential.hpp" 63 #include "Potentials/Exceptions.hpp" 64 #include "Potentials/PotentialDeserializer.hpp" 58 65 #include "Potentials/PotentialFactory.hpp" 59 66 #include "Potentials/PotentialRegistry.hpp" 60 67 #include "Potentials/PotentialSerializer.hpp" 61 #include "Potentials/PotentialTrainer.hpp"62 68 #include "Potentials/SerializablePotential.hpp" 63 #include "World.hpp"64 69 65 70 using namespace MoleCuilder; … … 70 75 /** =========== define the function ====================== */ 71 76 77 HomologyGraph getFirstGraphwithSpecifiedElements( 78 const HomologyContainer &homologies, 79 const SerializablePotential::ParticleTypes_t &types) 80 { 81 ASSERT( !types.empty(), 82 "getFirstGraphwithSpecifiedElements() - charges is empty?"); 83 // create charges 84 Fragment::charges_t charges; 85 charges.resize(types.size()); 86 std::transform(types.begin(), types.end(), 87 charges.begin(), boost::lambda::_1); 88 // convert into count map 89 Extractors::elementcounts_t counts_per_charge = 90 Extractors::_detail::getElementCounts(charges); 91 ASSERT( !counts_per_charge.empty(), 92 "getFirstGraphwithSpecifiedElements() - charge counts are empty?"); 93 LOG(2, "DEBUG: counts_per_charge is " << counts_per_charge << "."); 94 // we want to check each (unique) key only once 95 for (HomologyContainer::const_key_iterator iter = homologies.key_begin(); 96 iter != homologies.key_end(); iter = homologies.getNextKey(iter)) { 97 // check if every element has the right number of counts 98 Extractors::elementcounts_t::const_iterator countiter = counts_per_charge.begin(); 99 for (; countiter != counts_per_charge.end(); ++countiter) 100 if (!(*iter).hasTimesAtomicNumber( 101 static_cast<size_t>(countiter->first), 102 static_cast<size_t>(countiter->second)) 103 ) 104 break; 105 if( countiter == counts_per_charge.end()) 106 return *iter; 107 } 108 return HomologyGraph(); 109 } 110 111 SerializablePotential::ParticleTypes_t getNumbersFromElements( 112 const std::vector<const element *> &fragment) 113 { 114 SerializablePotential::ParticleTypes_t fragmentnumbers; 115 std::transform(fragment.begin(), fragment.end(), std::back_inserter(fragmentnumbers), 116 boost::bind(&element::getAtomicNumber, _1)); 117 return fragmentnumbers; 118 } 119 120 72 121 ActionState::ptr PotentialFitPotentialAction::performCall() { 73 122 // fragment specifies the homology fragment to use 74 123 SerializablePotential::ParticleTypes_t fragmentnumbers = 75 PotentialTrainer::getNumbersFromElements(params.fragment.get());124 getNumbersFromElements(params.fragment.get()); 76 125 77 126 // either charges and a potential is specified or a file 78 if (params.charges.get().empty()) { 79 STATUS("No charges given!"); 80 return Action::failure; 127 if (boost::filesystem::exists(params.potential_file.get())) { 128 std::ifstream returnstream(params.potential_file.get().string().c_str()); 129 if (returnstream.good()) { 130 try { 131 PotentialDeserializer deserialize(returnstream); 132 deserialize(); 133 } catch (SerializablePotentialMissingValueException &e) { 134 if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e)) 135 STATUS("Missing value when parsing information for potential "+*key+"."); 136 else 137 STATUS("Missing value parsing information for potential with unknown key."); 138 return Action::failure; 139 } catch (SerializablePotentialIllegalKeyException &e) { 140 if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e)) 141 STATUS("Illegal key parsing information for potential "+*key+"."); 142 else 143 STATUS("Illegal key parsing information for potential with unknown key."); 144 return Action::failure; 145 } 146 } else { 147 STATUS("Failed to parse from "+params.potential_file.get().string()+"."); 148 return Action::failure; 149 } 150 returnstream.close(); 151 152 LOG(0, "STATUS: I'm training now a set of potentials parsed from " 153 << params.potential_file.get().string() << " on a fragment " 154 << fragmentnumbers << " on data from World's homologies."); 155 81 156 } else { 82 // charges specify the potential type 83 SerializablePotential::ParticleTypes_t chargenumbers = 84 PotentialTrainer::getNumbersFromElements(params.charges.get()); 85 86 LOG(0, "STATUS: I'm training now a " << params.potentialtype.get() 87 << " potential on charges " << chargenumbers << " on data from World's homologies."); 88 89 // register desired potential and an additional constant one 90 { 91 EmpiricalPotential *potential = 92 PotentialFactory::getInstance().createInstance( 93 params.potentialtype.get(), 94 chargenumbers); 95 // check whether such a potential already exists 96 const std::string potential_name = potential->getName(); 97 if (PotentialRegistry::getInstance().isPresentByName(potential_name)) { 98 delete potential; 99 potential = PotentialRegistry::getInstance().getByName(potential_name); 100 } else 101 PotentialRegistry::getInstance().registerInstance(potential); 102 } 103 { 104 EmpiricalPotential *constant = 105 PotentialFactory::getInstance().createInstance( 106 std::string("constant"), 107 SerializablePotential::ParticleTypes_t()); 108 // check whether such a potential already exists 109 const std::string constant_name = constant->getName(); 110 if (PotentialRegistry::getInstance().isPresentByName(constant_name)) { 111 delete constant; 112 constant = PotentialRegistry::getInstance().getByName(constant_name); 113 } else 114 PotentialRegistry::getInstance().registerInstance(constant); 157 if (params.charges.get().empty()) { 158 STATUS("Neither charges nor potential file given!"); 159 return Action::failure; 160 } else { 161 // charges specify the potential type 162 SerializablePotential::ParticleTypes_t chargenumbers = 163 getNumbersFromElements(params.charges.get()); 164 165 LOG(0, "STATUS: I'm training now a " << params.potentialtype.get() 166 << " potential on charges " << chargenumbers << " on data from World's homologies."); 167 168 // register desired potential and an additional constant one 169 { 170 EmpiricalPotential *potential = 171 PotentialFactory::getInstance().createInstance( 172 params.potentialtype.get(), 173 chargenumbers); 174 // check whether such a potential already exists 175 const std::string potential_name = potential->getName(); 176 if (PotentialRegistry::getInstance().isPresentByName(potential_name)) { 177 delete potential; 178 potential = PotentialRegistry::getInstance().getByName(potential_name); 179 } else 180 PotentialRegistry::getInstance().registerInstance(potential); 181 } 182 { 183 EmpiricalPotential *constant = 184 PotentialFactory::getInstance().createInstance( 185 std::string("constant"), 186 SerializablePotential::ParticleTypes_t()); 187 // check whether such a potential already exists 188 const std::string constant_name = constant->getName(); 189 if (PotentialRegistry::getInstance().isPresentByName(constant_name)) { 190 delete constant; 191 constant = PotentialRegistry::getInstance().getByName(constant_name); 192 } else 193 PotentialRegistry::getInstance().registerInstance(constant); 194 } 115 195 } 116 196 } 117 197 118 198 // parse homologies into container 119 constHomologyContainer &homologies = World::getInstance().getHomologies();199 HomologyContainer &homologies = World::getInstance().getHomologies(); 120 200 121 201 // first we try to look into the HomologyContainer … … 131 211 132 212 // then we ought to pick the right HomologyGraph ... 133 const HomologyGraph graph = 134 PotentialTrainer::getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers); 213 const HomologyGraph graph = getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers); 135 214 if (graph != HomologyGraph()) { 136 215 LOG(1, "First representative graph containing fragment " … … 141 220 } 142 221 143 // training 144 PotentialTrainer trainer; 145 const bool status = trainer( 146 homologies, 147 graph, 148 params.training_file.get(), 149 params.threshold.get(), 150 params.best_of_howmany.get()); 151 if (!status) { 152 STATUS("No required parameter derivatives for a box constraint minimization known."); 153 return Action::failure; 154 } 222 // fit potential 223 FunctionModel *model = new CompoundPotential(graph); 224 ASSERT( model != NULL, 225 "PotentialFitPotentialAction::performCall() - model is NULL."); 226 227 /******************** TRAINING ********************/ 228 // fit potential 229 FunctionModel::parameters_t bestparams(model->getParameterDimension(), 0.); 230 { 231 // Afterwards we go through all of this type and gather the distance and the energy value 232 TrainingData data(model->getSpecificFilter()); 233 data(homologies.getHomologousGraphs(graph)); 234 235 // print distances and energies if desired for debugging 236 if (!data.getTrainingInputs().empty()) { 237 // print which distance is which 238 size_t counter=1; 239 if (DoLog(3)) { 240 const FunctionModel::arguments_t &inputs = data.getAllArguments()[0]; 241 for (FunctionModel::arguments_t::const_iterator iter = inputs.begin(); 242 iter != inputs.end(); ++iter) { 243 const argument_t &arg = *iter; 244 LOG(3, "DEBUG: distance " << counter++ << " is between (#" 245 << arg.indices.first << "c" << arg.types.first << "," 246 << arg.indices.second << "c" << arg.types.second << ")."); 247 } 248 } 249 250 // print table 251 if (params.training_file.get().string().empty()) { 252 LOG(3, "DEBUG: I gathered the following training data:\n" << 253 _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable())); 254 } else { 255 std::ofstream trainingstream(params.training_file.get().string().c_str()); 256 if (trainingstream.good()) { 257 LOG(3, "DEBUG: Writing training data to file " << 258 params.training_file.get().string() << "."); 259 trainingstream << _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable()); 260 } 261 trainingstream.close(); 262 } 263 } 264 265 if ((params.threshold.get() < 1) && (params.best_of_howmany.isSet())) 266 ELOG(2, "threshold parameter always overrules max_runs, both are specified."); 267 // now perform the function approximation by optimizing the model function 268 FunctionApproximation approximator(data, *model); 269 if (model->isBoxConstraint() && approximator.checkParameterDerivatives()) { 270 double l2error = std::numeric_limits<double>::max(); 271 // seed with current time 272 srand((unsigned)time(0)); 273 unsigned int runs=0; 274 // threshold overrules max_runs 275 const double threshold = params.threshold.get(); 276 const unsigned int max_runs = (threshold >= 1.) ? 277 (params.best_of_howmany.isSet() ? params.best_of_howmany.get() : 1) : 0; 278 LOG(1, "INFO: Maximum runs is " << max_runs << " and threshold set to " << threshold << "."); 279 do { 280 // generate new random initial parameter values 281 model->setParametersToRandomInitialValues(data); 282 LOG(1, "INFO: Initial parameters of run " << runs << " are " 283 << model->getParameters() << "."); 284 approximator(FunctionApproximation::ParameterDerivative); 285 LOG(1, "INFO: Final parameters of run " << runs << " are " 286 << model->getParameters() << "."); 287 const double new_l2error = data.getL2Error(*model); 288 if (new_l2error < l2error) { 289 // store currently best parameters 290 l2error = new_l2error; 291 bestparams = model->getParameters(); 292 LOG(1, "STATUS: New fit from run " << runs 293 << " has better error of " << l2error << "."); 294 } 295 } while (( ++runs < max_runs) || (l2error > threshold)); 296 // reset parameters from best fit 297 model->setParameters(bestparams); 298 LOG(1, "INFO: Best parameters with L2 error of " 299 << l2error << " are " << model->getParameters() << "."); 300 } else { 301 STATUS("No required parameter derivatives for a box constraint minimization known."); 302 return Action::failure; 303 } 304 305 // create a map of each fragment with error. 306 HomologyContainer::range_t fragmentrange = homologies.getHomologousGraphs(graph); 307 TrainingData::L2ErrorConfigurationIndexMap_t WorseFragmentMap = 308 data.getWorstFragmentMap(*model, fragmentrange); 309 LOG(0, "RESULT: WorstFragmentMap " << WorseFragmentMap << "."); 310 311 // print fitted potentials 312 std::stringstream potentials; 313 PotentialSerializer serialize(potentials); 314 serialize(); 315 LOG(1, "STATUS: Resulting parameters are " << std::endl << potentials.str()); 316 std::ofstream returnstream(params.potential_file.get().string().c_str()); 317 if (returnstream.good()) { 318 returnstream << potentials.str(); 319 } 320 } 321 delete model; 155 322 156 323 return Action::success;
Note:
See TracChangeset
for help on using the changeset viewer.