LinearCombinator.cc
1 /*
2  * MoMEMta: a modular implementation of the Matrix Element Method
3  * Copyright (C) 2016 Universite catholique de Louvain (UCL), Belgium
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  */
18 
19 #include <momemta/ParameterSet.h>
20 #include <momemta/Module.h>
21 #include <momemta/Types.h>
22 #include <momemta/Utils.h>
23 
24 #include <stdexcept>
25 
55 template<typename T>
56 class LinearCombinator: public Module {
57  public:
58 
59  LinearCombinator(PoolPtr pool, const ParameterSet& parameters): Module(pool, parameters.getModuleName()) {
60  auto tags = parameters.get<std::vector<InputTag>>("inputs");
61  for (auto& v: tags)
62  m_terms.push_back(get<T>(v));
63 
64  m_coefficients = parameters.get<std::vector<double>>("coefficients");
65 
66  if (m_coefficients.size() == 0 || m_terms.size() == 0){
67  auto exception = std::invalid_argument("Tried to call LinearCombinator with an empty input.");
68  LOG(fatal) << exception.what();
69  throw exception;
70  }
71  if (m_coefficients.size() != m_terms.size()){
72  auto exception = std::invalid_argument("The Term and Coefficient lists passed to LinearCombinator have different sizes.");
73  LOG(fatal) << exception.what();
74  throw exception;
75  }
76  };
77 
78  virtual Status work() override {
79 
80  T temp_result = m_coefficients[0] * *m_terms[0];
81  for (std::size_t i = 1; i < m_terms.size(); i++)
82  temp_result += m_coefficients[i] * *m_terms[i];
83 
84  *output = temp_result;
85 
86  return Status::OK;
87  }
88 
89  private:
90 
91  // Inputs
92  std::vector<double> m_coefficients;
93  std::vector<Value<T>> m_terms;
94 
95  // Outputs
96  std::shared_ptr<T> output = produce<T>("output");
97 };
98 
// Registers LinearCombinator<LorentzVector> under the module name "VectorLinearCombinator".
// The declared names match what the class uses: it reads the "inputs" and
// "coefficients" parameters and produces a value named "output".
99 REGISTER_MODULE_NAME("VectorLinearCombinator", LinearCombinator<LorentzVector>)
100  .Inputs("inputs")
101  .Output("output")
102  .Attr("coefficients:list(double)");
103 
// Registers LinearCombinator<double> under the module name "DoubleLinearCombinator".
// Same declared inputs, output and "coefficients" attribute as the other
// LinearCombinator registrations in this file.
104 REGISTER_MODULE_NAME("DoubleLinearCombinator", LinearCombinator<double>)
105  .Inputs("inputs")
106  .Output("output")
107  .Attr("coefficients:list(double)");
108 
// Registers LinearCombinator<int64_t> under the module name "IntLinearCombinator".
// The terms and the output are 64-bit integers, but the "coefficients"
// attribute is still declared as a list of doubles.
109 REGISTER_MODULE_NAME("IntLinearCombinator", LinearCombinator<int64_t>)
110  .Inputs("inputs")
111  .Output("output")
112  .Attr("coefficients:list(double)");
Performs a linear combination of templated terms.
Parent class for all the modules.
Definition: Module.h:37
A class encapsulating a lua table.
Definition: ParameterSet.h:82
virtual Status work() override
Main function.
Module(PoolPtr pool, const std::string &name)
Constructor.
Definition: Module.h:61