Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/Actions/PotentialAction/FitPotentialAction.cpp

    r0ea063 r707a2b  
    5555#include "Fragmentation/Homology/HomologyGraph.hpp"
    5656#include "Fragmentation/Summation/SetValues/Fragment.hpp"
    57 #include "Potentials/EmpiricalPotential.hpp"
     57#include "FunctionApproximation/Extractors.hpp"
     58#include "FunctionApproximation/FunctionApproximation.hpp"
     59#include "FunctionApproximation/FunctionModel.hpp"
     60#include "FunctionApproximation/TrainingData.hpp"
     61#include "FunctionApproximation/writeDistanceEnergyTable.hpp"
     62#include "Potentials/CompoundPotential.hpp"
     63#include "Potentials/Exceptions.hpp"
     64#include "Potentials/PotentialDeserializer.hpp"
    5865#include "Potentials/PotentialFactory.hpp"
    5966#include "Potentials/PotentialRegistry.hpp"
    6067#include "Potentials/PotentialSerializer.hpp"
    61 #include "Potentials/PotentialTrainer.hpp"
    6268#include "Potentials/SerializablePotential.hpp"
    63 #include "World.hpp"
    6469
    6570using namespace MoleCuilder;
     
    7075/** =========== define the function ====================== */
    7176
     /** Searches the keys of a HomologyContainer for the first graph that matches
      * a given list of particle types.
      *
      * The particle types are copied into a Fragment::charges_t, condensed into
      * a map of (element, count) entries via Extractors::_detail::getElementCounts()
      * and each unique key of \a homologies is then tested against all entries
      * with HomologyGraph::hasTimesAtomicNumber().
      *
      * NOTE(review): presumably hasTimesAtomicNumber(Z, n) tells whether the graph
      * contains atomic number Z exactly n times -- confirm against HomologyGraph.
      *
      * \param homologies container whose keys are inspected
      * \param types particle types to look for, must not be empty (asserted)
      * \return first key for which every (element, count) entry passes the check,
      *         or a default-constructed HomologyGraph if no key passes
      */
     77HomologyGraph getFirstGraphwithSpecifiedElements(
     78    const HomologyContainer &homologies,
     79    const SerializablePotential::ParticleTypes_t &types)
     80{
     // NOTE(review): the assertion message speaks of "charges" although the
     // condition checks \a types.
     81  ASSERT( !types.empty(),
     82      "getFirstGraphwithSpecifiedElements() - charges is empty?");
     83  // create charges: one entry per particle type, copied element-wise
     84  Fragment::charges_t charges;
     85  charges.resize(types.size());
     // the bare placeholder boost::lambda::_1 serves as unary functor handing
     // back its argument, i.e. this transform is an element-wise copy
     86  std::transform(types.begin(), types.end(),
     87      charges.begin(), boost::lambda::_1);
     88  // convert into count map (entries are used below as element -> count)
     89  Extractors::elementcounts_t counts_per_charge =
     90      Extractors::_detail::getElementCounts(charges);
     91  ASSERT( !counts_per_charge.empty(),
     92      "getFirstGraphwithSpecifiedElements() - charge counts are empty?");
     93  LOG(2, "DEBUG: counts_per_charge is " << counts_per_charge << ".");
     94  // we want to check each (unique) key only once, hence getNextKey() instead of ++iter
     95  for (HomologyContainer::const_key_iterator iter = homologies.key_begin();
     96      iter != homologies.key_end(); iter = homologies.getNextKey(iter)) {
     97    // check if every element has the right number of counts; stop at first mismatch
     98    Extractors::elementcounts_t::const_iterator countiter = counts_per_charge.begin();
     99    for (; countiter != counts_per_charge.end(); ++countiter)
     100      if (!(*iter).hasTimesAtomicNumber(
     101          static_cast<size_t>(countiter->first),
     102          static_cast<size_t>(countiter->second))
     103          )
     104        break;
     // inner loop ran to completion, i.e. no entry failed: this key is the match
     105    if( countiter == counts_per_charge.end())
     106      return *iter;
     107  }
     // no key passed all checks: hand back a default-constructed graph as "not found"
     108  return HomologyGraph();
     109}
     110
     /** Converts a list of element pointers into the list of their atomic numbers.
      *
      * element::getAtomicNumber() is called on every entry of \a fragment and the
      * results are appended in the same order.
      *
      * \param fragment vector of element pointers; each pointer is dereferenced
      *        by the bound member function call, so none may be NULL
      * \return particle types with one atomic number per entry of \a fragment,
      *         empty if \a fragment is empty
      */
     111SerializablePotential::ParticleTypes_t getNumbersFromElements(
     112    const std::vector<const element *> &fragment)
     113{
     114  SerializablePotential::ParticleTypes_t fragmentnumbers;
     // back_inserter appends, so the result grows to fragment.size() entries
     115  std::transform(fragment.begin(), fragment.end(), std::back_inserter(fragmentnumbers),
     116      boost::bind(&element::getAtomicNumber, _1));
     117  return fragmentnumbers;
     118}
     119
     120
    72121ActionState::ptr PotentialFitPotentialAction::performCall() {
    73122  // fragment specifies the homology fragment to use
    74123  SerializablePotential::ParticleTypes_t fragmentnumbers =
    75       PotentialTrainer::getNumbersFromElements(params.fragment.get());
     124      getNumbersFromElements(params.fragment.get());
    76125
    77126  // either charges and a potential type are specified, or a potential file
    78   if (params.charges.get().empty()) {
    79     STATUS("No charges given!");
    80     return Action::failure;
     127  if (boost::filesystem::exists(params.potential_file.get())) {
     128    std::ifstream returnstream(params.potential_file.get().string().c_str());
     129    if (returnstream.good()) {
     130      try {
     131        PotentialDeserializer deserialize(returnstream);
     132        deserialize();
     133      } catch (SerializablePotentialMissingValueException &e) {
     134        if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e))
     135          STATUS("Missing value when parsing information for potential "+*key+".");
     136        else
     137          STATUS("Missing value parsing information for potential with unknown key.");
     138        return Action::failure;
     139      } catch (SerializablePotentialIllegalKeyException &e) {
     140        if (const std::string *key = boost::get_error_info<SerializablePotentialKey>(e))
     141          STATUS("Illegal key parsing information for potential "+*key+".");
     142        else
     143          STATUS("Illegal key parsing information for potential with unknown key.");
     144        return Action::failure;
     145      }
     146    } else {
     147      STATUS("Failed to parse from "+params.potential_file.get().string()+".");
     148      return Action::failure;
     149    }
     150    returnstream.close();
     151
     152    LOG(0, "STATUS: I'm training now a set of potentials parsed from "
     153        << params.potential_file.get().string() << " on a fragment "
     154        << fragmentnumbers << " on data from World's homologies.");
     155
    81156  } else {
    82     // charges specify the potential type
    83     SerializablePotential::ParticleTypes_t chargenumbers =
    84         PotentialTrainer::getNumbersFromElements(params.charges.get());
    85 
    86     LOG(0, "STATUS: I'm training now a " << params.potentialtype.get()
    87         << " potential on charges " << chargenumbers << " on data from World's homologies.");
    88 
    89     // register desired potential and an additional constant one
    90     {
    91       EmpiricalPotential *potential =
    92           PotentialFactory::getInstance().createInstance(
    93               params.potentialtype.get(),
    94               chargenumbers);
    95       // check whether such a potential already exists
    96       const std::string potential_name = potential->getName();
    97       if (PotentialRegistry::getInstance().isPresentByName(potential_name)) {
    98         delete potential;
    99         potential = PotentialRegistry::getInstance().getByName(potential_name);
    100       } else
    101         PotentialRegistry::getInstance().registerInstance(potential);
    102     }
    103     {
    104       EmpiricalPotential *constant =
    105           PotentialFactory::getInstance().createInstance(
    106               std::string("constant"),
    107               SerializablePotential::ParticleTypes_t());
    108       // check whether such a potential already exists
    109       const std::string constant_name = constant->getName();
    110       if (PotentialRegistry::getInstance().isPresentByName(constant_name)) {
    111         delete constant;
    112         constant = PotentialRegistry::getInstance().getByName(constant_name);
    113       } else
    114         PotentialRegistry::getInstance().registerInstance(constant);
     157    if (params.charges.get().empty()) {
     158      STATUS("Neither charges nor potential file given!");
     159      return Action::failure;
     160    } else {
     161      // charges specify the potential type
     162      SerializablePotential::ParticleTypes_t chargenumbers =
     163          getNumbersFromElements(params.charges.get());
     164
     165      LOG(0, "STATUS: I'm training now a " << params.potentialtype.get()
     166          << " potential on charges " << chargenumbers << " on data from World's homologies.");
     167
     168      // register desired potential and an additional constant one
     169      {
     170        EmpiricalPotential *potential =
     171            PotentialFactory::getInstance().createInstance(
     172                params.potentialtype.get(),
     173                chargenumbers);
     174        // check whether such a potential already exists
     175        const std::string potential_name = potential->getName();
     176        if (PotentialRegistry::getInstance().isPresentByName(potential_name)) {
     177          delete potential;
     178          potential = PotentialRegistry::getInstance().getByName(potential_name);
     179        } else
     180          PotentialRegistry::getInstance().registerInstance(potential);
     181      }
     182      {
     183        EmpiricalPotential *constant =
     184            PotentialFactory::getInstance().createInstance(
     185                std::string("constant"),
     186                SerializablePotential::ParticleTypes_t());
     187        // check whether such a potential already exists
     188        const std::string constant_name = constant->getName();
     189        if (PotentialRegistry::getInstance().isPresentByName(constant_name)) {
     190          delete constant;
     191          constant = PotentialRegistry::getInstance().getByName(constant_name);
     192        } else
     193          PotentialRegistry::getInstance().registerInstance(constant);
     194      }
    115195    }
    116196  }
    117197
    118198  // parse homologies into container
    119   const HomologyContainer &homologies = World::getInstance().getHomologies();
     199  HomologyContainer &homologies = World::getInstance().getHomologies();
    120200
    121201  // first we try to look into the HomologyContainer
     
    131211
    132212  // then we ought to pick the right HomologyGraph ...
    133   const HomologyGraph graph =
    134       PotentialTrainer::getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers);
     213  const HomologyGraph graph = getFirstGraphwithSpecifiedElements(homologies,fragmentnumbers);
    135214  if (graph != HomologyGraph()) {
    136215    LOG(1, "First representative graph containing fragment "
     
    141220  }
    142221
    143   // training
    144   PotentialTrainer trainer;
    145   const bool status = trainer(
    146       homologies,
    147       graph,
    148       params.training_file.get(),
    149       params.threshold.get(),
    150       params.best_of_howmany.get());
    151   if (!status) {
    152     STATUS("No required parameter derivatives for a box constraint minimization known.");
    153     return Action::failure;
    154   }
     222  // fit potential
     223  FunctionModel *model = new CompoundPotential(graph);
     224  ASSERT( model != NULL,
     225      "PotentialFitPotentialAction::performCall() - model is NULL.");
     226
     227  /******************** TRAINING ********************/
     228  // fit potential
     229  FunctionModel::parameters_t bestparams(model->getParameterDimension(), 0.);
     230  {
     231    // Afterwards we go through all of this type and gather the distance and the energy value
     232    TrainingData data(model->getSpecificFilter());
     233    data(homologies.getHomologousGraphs(graph));
     234
     235    // print distances and energies if desired for debugging
     236    if (!data.getTrainingInputs().empty()) {
     237      // print which distance is which
     238      size_t counter=1;
     239      if (DoLog(3)) {
     240        const FunctionModel::arguments_t &inputs = data.getAllArguments()[0];
     241        for (FunctionModel::arguments_t::const_iterator iter = inputs.begin();
     242            iter != inputs.end(); ++iter) {
     243          const argument_t &arg = *iter;
     244          LOG(3, "DEBUG: distance " << counter++ << " is between (#"
     245              << arg.indices.first << "c" << arg.types.first << ","
     246              << arg.indices.second << "c" << arg.types.second << ").");
     247        }
     248      }
     249
     250      // print table
     251      if (params.training_file.get().string().empty()) {
     252        LOG(3, "DEBUG: I gathered the following training data:\n" <<
     253            _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable()));
     254      } else {
     255        std::ofstream trainingstream(params.training_file.get().string().c_str());
     256        if (trainingstream.good()) {
     257          LOG(3, "DEBUG: Writing training data to file " <<
     258              params.training_file.get().string() << ".");
     259          trainingstream << _detail::writeDistanceEnergyTable(data.getDistanceEnergyTable());
     260        }
     261        trainingstream.close();
     262      }
     263    }
     264
     265    if ((params.threshold.get() < 1) && (params.best_of_howmany.isSet()))
     266      ELOG(2, "threshold parameter always overrules max_runs, both are specified.");
     267    // now perform the function approximation by optimizing the model function
     268    FunctionApproximation approximator(data, *model);
     269    if (model->isBoxConstraint() && approximator.checkParameterDerivatives()) {
     270      double l2error = std::numeric_limits<double>::max();
     271      // seed with current time
     272      srand((unsigned)time(0));
     273      unsigned int runs=0;
     274      // threshold overrules max_runs
     275      const double threshold = params.threshold.get();
     276      const unsigned int max_runs = (threshold >= 1.) ?
     277          (params.best_of_howmany.isSet() ? params.best_of_howmany.get() : 1) : 0;
     278      LOG(1, "INFO: Maximum runs is " << max_runs << " and threshold set to " << threshold << ".");
     279      do {
     280        // generate new random initial parameter values
     281        model->setParametersToRandomInitialValues(data);
     282        LOG(1, "INFO: Initial parameters of run " << runs << " are "
     283            << model->getParameters() << ".");
     284        approximator(FunctionApproximation::ParameterDerivative);
     285        LOG(1, "INFO: Final parameters of run " << runs << " are "
     286            << model->getParameters() << ".");
     287        const double new_l2error = data.getL2Error(*model);
     288        if (new_l2error < l2error) {
     289          // store currently best parameters
     290          l2error = new_l2error;
     291          bestparams = model->getParameters();
     292          LOG(1, "STATUS: New fit from run " << runs
     293              << " has better error of " << l2error << ".");
     294        }
     295      } while (( ++runs < max_runs) || (l2error > threshold));
     296      // reset parameters from best fit
     297      model->setParameters(bestparams);
     298      LOG(1, "INFO: Best parameters with L2 error of "
     299          << l2error << " are " << model->getParameters() << ".");
     300    } else {
     301      STATUS("No required parameter derivatives for a box constraint minimization known.");
     302      return Action::failure;
     303    }
     304
     305    // create a map of each fragment with error.
     306    HomologyContainer::range_t fragmentrange = homologies.getHomologousGraphs(graph);
     307    TrainingData::L2ErrorConfigurationIndexMap_t WorseFragmentMap =
     308        data.getWorstFragmentMap(*model, fragmentrange);
     309    LOG(0, "RESULT: WorstFragmentMap " << WorseFragmentMap << ".");
     310
     311    // print fitted potentials
     312    std::stringstream potentials;
     313    PotentialSerializer serialize(potentials);
     314    serialize();
     315    LOG(1, "STATUS: Resulting parameters are " << std::endl << potentials.str());
     316    std::ofstream returnstream(params.potential_file.get().string().c_str());
     317    if (returnstream.good()) {
     318      returnstream << potentials.str();
     319    }
     320  }
     321  delete model;
    155322
    156323  return Action::success;
Note: See TracChangeset for help on using the changeset viewer.