source: src/Potentials/PotentialTrainer.cpp@ 82e5fb

Candidate_v1.7.0 stable
Last change on this file since 82e5fb was 82e5fb, checked in by Frederik Heber <frederik.heber@…>, 4 years ago

Added option error-file to fit potential actions.

  • Property mode set to 100644
File size: 10.1 KB
RevLine 
[98d166]1/*
2 * Project: MoleCuilder
3 * Description: creates and alters molecular systems
4 * Copyright (C) 2014 Frederik Heber. All rights reserved.
5 *
6 *
7 * This file is part of MoleCuilder.
8 *
9 * MoleCuilder is free software: you can redistribute it and/or modify
10 * it under the terms of the GNU General Public License as published by
11 * the Free Software Foundation, either version 2 of the License, or
12 * (at your option) any later version.
13 *
14 * MoleCuilder is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 * GNU General Public License for more details.
18 *
19 * You should have received a copy of the GNU General Public License
20 * along with MoleCuilder. If not, see <http://www.gnu.org/licenses/>.
21 */
22
23/*
24 * PotentialTrainer.cpp
25 *
26 * Created on: Sep 11, 2014
27 * Author: heber
28 */
29
30// include config.h
31#ifdef HAVE_CONFIG_H
32#include <config.h>
33#endif
34
35// needs to come before MemDebug due to placement new
36#include <boost/archive/text_iarchive.hpp>
37
[9eb71b3]38//#include "CodePatterns/MemDebug.hpp"
[98d166]39
40#include "PotentialTrainer.hpp"
41
42#include <algorithm>
43#include <boost/lambda/lambda.hpp>
44#include <boost/filesystem.hpp>
45#include <fstream>
46#include <sstream>
47
48#include "CodePatterns/Assert.hpp"
49#include "CodePatterns/Log.hpp"
50
51#include "Element/element.hpp"
52#include "Fragmentation/Homology/HomologyContainer.hpp"
53#include "Fragmentation/Homology/HomologyGraph.hpp"
54#include "FunctionApproximation/Extractors.hpp"
55#include "FunctionApproximation/FunctionApproximation.hpp"
56#include "FunctionApproximation/FunctionModel.hpp"
57#include "FunctionApproximation/TrainingData.hpp"
58#include "FunctionApproximation/writeDistanceEnergyTable.hpp"
59#include "Potentials/CompoundPotential.hpp"
[fde8e7]60#include "Potentials/RegistrySerializer.hpp"
[98d166]61#include "Potentials/SerializablePotential.hpp"
62
63PotentialTrainer::PotentialTrainer()
64{}
65
66PotentialTrainer::~PotentialTrainer()
67{}
68
69bool PotentialTrainer::operator()(
70 const HomologyContainer &_homologies,
71 const HomologyGraph &_graph,
72 const boost::filesystem::path &_trainingfile,
[82e5fb]73 const boost::filesystem::path &_errorfile,
[b40690]74 const unsigned int _maxiterations,
[98d166]75 const double _threshold,
76 const unsigned int _best_of_howmany) const
77{
78 // fit potential
[3400bb]79 CompoundPotential compound(_graph);
80 FunctionModel &model = assert_cast<FunctionModel &>(compound);
[98d166]81
[3400bb]82 if (compound.begin() == compound.end()) {
83 ELOG(1, "Could not find any suitable potentials for the compound potential.");
84 return false;
[29ce5f]85 }
86
[98d166]87 /******************** TRAINING ********************/
88 // fit potential
[3400bb]89 FunctionModel::parameters_t bestparams(model.getParameterDimension(), 0.);
[98d166]90 {
91 // Afterwards we go through all of this type and gather the distance and the energy value
[3400bb]92 TrainingData data(model.getSpecificFilter());
[98d166]93 data(_homologies.getHomologousGraphs(_graph));
94
[d33f24]95 // check data
96 const TrainingData::FilteredInputVector_t &inputs = data.getTrainingInputs();
97 for (TrainingData::FilteredInputVector_t::const_iterator iter = inputs.begin();
98 iter != inputs.end(); ++iter)
99 if (((*iter).empty()) || ((*iter).front().empty())) {
100 ELOG(1, "At least one of the training inputs is empty! Correct fragment and potential charges selected?");
101 return false;
102 }
103 const TrainingData::OutputVector_t &outputs = data.getTrainingOutputs();
104 for (TrainingData::OutputVector_t::const_iterator iter = outputs.begin();
105 iter != outputs.end(); ++iter)
106 if ((*iter).empty()) {
107 ELOG(1, "At least one of the training outputs is empty! Correct fragment and potential charges selected?");
108 return false;
109 }
110
[98d166]111 // print distances and energies if desired for debugging
112 if (!data.getTrainingInputs().empty()) {
113 // print which distance is which
114 size_t counter=1;
115 if (DoLog(3)) {
116 const FunctionModel::arguments_t &inputs = data.getAllArguments()[0];
117 for (FunctionModel::arguments_t::const_iterator iter = inputs.begin();
118 iter != inputs.end(); ++iter) {
119 const argument_t &arg = *iter;
120 LOG(3, "DEBUG: distance " << counter++ << " is between (#"
121 << arg.indices.first << "c" << arg.types.first << ","
122 << arg.indices.second << "c" << arg.types.second << ").");
123 }
124 }
125
126 // print table
127 if (_trainingfile.string().empty()) {
128 LOG(3, "DEBUG: I gathered the following training data:\n" <<
129 _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable()));
130 } else {
131 std::ofstream trainingstream(_trainingfile.string().c_str());
132 if (trainingstream.good()) {
133 LOG(3, "DEBUG: Writing training data to file " <<
134 _trainingfile.string() << ".");
135 trainingstream << _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable());
136 }
137 trainingstream.close();
138 }
139 }
140
141 if ((_threshold < 1.) && (_best_of_howmany))
142 ELOG(2, "threshold parameter always overrules max_runs, both are specified.");
143 // now perform the function approximation by optimizing the model function
[3400bb]144 FunctionApproximation approximator(data, model, _threshold, _maxiterations);
145 if (model.isBoxConstraint() && approximator.checkParameterDerivatives()) {
[98d166]146 double l2error = std::numeric_limits<double>::max();
147 // seed with current time
148 srand((unsigned)time(0));
149 unsigned int runs=0;
150 // threshold overrules max_runs
151 const double threshold = _threshold;
[20fc6f]152 const unsigned int max_runs = (threshold >= 1.) ? _best_of_howmany : std::numeric_limits<unsigned int>::max();
[98d166]153 LOG(1, "INFO: Maximum runs is " << max_runs << " and threshold set to " << threshold << ".");
154 do {
155 // generate new random initial parameter values
[3400bb]156 model.setParametersToRandomInitialValues(data);
[98d166]157 LOG(1, "INFO: Initial parameters of run " << runs << " are "
[3400bb]158 << model.getParameters() << ".");
[98d166]159 approximator(FunctionApproximation::ParameterDerivative);
160 LOG(1, "INFO: Final parameters of run " << runs << " are "
[3400bb]161 << model.getParameters() << ".");
162 const double new_l2error = data.getL2Error(model);
[98d166]163 if (new_l2error < l2error) {
164 // store currently best parameters
165 l2error = new_l2error;
[3400bb]166 bestparams = model.getParameters();
[98d166]167 LOG(1, "STATUS: New fit from run " << runs
168 << " has better error of " << l2error << ".");
169 }
[20fc6f]170 } while (( ++runs < max_runs) && (l2error > threshold));
[98d166]171 // reset parameters from best fit
[3400bb]172 model.setParameters(bestparams);
[98d166]173 LOG(1, "INFO: Best parameters with L2 error of "
[3400bb]174 << l2error << " are " << model.getParameters() << ".");
[98d166]175 } else {
176 return false;
177 }
178
179 // create a map of each fragment with error.
180 HomologyContainer::range_t fragmentrange = _homologies.getHomologousGraphs(_graph);
181 TrainingData::L2ErrorConfigurationIndexMap_t WorseFragmentMap =
[3400bb]182 data.getWorstFragmentMap(model, fragmentrange);
[82e5fb]183 if (_errorfile.string().empty()) {
184 LOG(0, "RESULT: WorstFragmentMap " << WorseFragmentMap << ".");
185 } else {
186 std::ofstream errorstream(_errorfile.string().c_str());
187 if (errorstream.good()) {
188 LOG(3, "DEBUG: Writing error data to file " <<
189 _errorfile.string() << ".");
190 errorstream << "step\terror" << std::endl;
191 // resort into step as key
192 typedef std::map< size_t, double > step_error_t;
193 step_error_t step_error;
194 for (TrainingData::L2ErrorConfigurationIndexMap_t::const_reverse_iterator iter = WorseFragmentMap.rbegin();
195 iter != WorseFragmentMap.rend(); ++iter)
196 step_error.insert( std::make_pair(iter->second, iter->first) );
197 for (step_error_t::const_iterator iter = step_error.begin();
198 iter != step_error.end(); ++iter)
199 errorstream << iter->first << "\t" << iter->second << std::endl;
200 }
201 errorstream.close();
202 }
[98d166]203 }
204
205 return true;
206}
207
208HomologyGraph PotentialTrainer::getFirstGraphwithSpecifiedElements(
209 const HomologyContainer &homologies,
210 const SerializablePotential::ParticleTypes_t &types)
211{
212 ASSERT( !types.empty(),
213 "getFirstGraphwithSpecifiedElements() - charges is empty?");
[c5e75f3]214
[98d166]215 // convert into count map
[c5e75f3]216 Extractors::elementcounts_t counts_per_element =
217 Extractors::_detail::getElementCounts(types);
218 ASSERT( !counts_per_element.empty(),
219 "getFirstGraphwithSpecifiedElements() - element counts are empty?");
220 LOG(1, "DEBUG: counts_per_element is " << counts_per_element << ".");
[98d166]221 // we want to check each (unique) key only once
222 HomologyContainer::const_key_iterator olditer = homologies.key_end();
223 for (HomologyContainer::const_key_iterator iter =
[e63edb]224 homologies.key_begin(); iter != homologies.key_end();
225 iter = homologies.getNextKey(iter)) {
[98d166]226 // if it's the same as the old one, skip it
[e63edb]227 if (olditer == iter)
[98d166]228 continue;
[e63edb]229 else
230 olditer = iter;
[945797]231 // check whether we have the same set of atomic numbers
232 const HomologyGraph::nodes_t &nodes = (*iter).getNodes();
[c5e75f3]233 Extractors::elementcounts_t nodes_counts_per_element;
[945797]234 for (HomologyGraph::nodes_t::const_iterator nodeiter = nodes.begin();
235 nodeiter != nodes.end(); ++nodeiter) {
236 const Extractors::element_t elem = nodeiter->first.getAtomicNumber();
237 const std::pair<Extractors::elementcounts_t::iterator, bool> inserter =
[c5e75f3]238 nodes_counts_per_element.insert( std::make_pair(elem, (Extractors::count_t)nodeiter->second ) );
[945797]239 if (!inserter.second)
240 inserter.first->second += (Extractors::count_t)nodeiter->second;
241 }
[c5e75f3]242 LOG(1, "DEBUG: Node (" << *iter << ")'s counts_per_element is " << nodes_counts_per_element << ".");
243 if (counts_per_element == nodes_counts_per_element)
[98d166]244 return *iter;
245 }
246 return HomologyGraph();
247}
248
249SerializablePotential::ParticleTypes_t PotentialTrainer::getNumbersFromElements(
250 const std::vector<const element *> &fragment)
251{
252 SerializablePotential::ParticleTypes_t fragmentnumbers;
253 std::transform(fragment.begin(), fragment.end(), std::back_inserter(fragmentnumbers),
254 boost::bind(&element::getAtomicNumber, _1));
255 return fragmentnumbers;
256}
Note: See TracBrowser for help on using the repository browser.