BlockC.cc
1 /*
2  * MoMEMta: a modular implementation of the Matrix Element Method
3  * Copyright (C) 2017 Universite catholique de Louvain (UCL), Belgium
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  */
18 
19 #include <momemta/Math.h>
20 #include <momemta/Module.h>
21 #include <momemta/ParameterSet.h>
22 #include <momemta/Solution.h>
23 #include <momemta/Types.h>
24 
81 class BlockC : public Module {
82 public:
83  BlockC(PoolPtr pool, const ParameterSet& parameters) : Module(pool, parameters.getModuleName()) {
84  sqrt_s = parameters.globalParameters().get<double>("energy");
85  pT_is_met = parameters.get<bool>("pT_is_met", false);
86 
87  s12 = get<double>(parameters.get<InputTag>("s12"));
88  s123 = get<double>(parameters.get<InputTag>("s123"));
89 
90  m1 = parameters.get<double>("m1", 0.);
91 
92  p2 = get<LorentzVector>(parameters.get<InputTag>("p2"));
93  p3 = get<LorentzVector>(parameters.get<InputTag>("p3"));
94 
95  if (parameters.exists("branches")) {
96  auto branches_tags = parameters.get<std::vector<InputTag>>("branches");
97  for (auto& t : branches_tags)
98  m_branches.push_back(get<LorentzVector>(t));
99  }
100 
101  // If the met input is specified, get it, otherwise retrieve default
102  // one ("met::p4")
103  InputTag met_tag;
104  if (parameters.exists("met")) {
105  met_tag = parameters.get<InputTag>("met");
106  } else {
107  met_tag = InputTag({"met", "p4"});
108  }
109  m_met = get<LorentzVector>(met_tag);
110  };
111 
112  virtual Status work() override {
113 
114  solutions->clear();
115 
116  const double p2Sq = p2->M2();
117 
118  // Don't spend time on unphysical corner of the phase-space
119  if (*s12 >= *s123 || *s123 >= SQ(sqrt_s) || *s12 <= p2Sq + SQ(m1))
120  return Status::NEXT;
121 
122  // pT will be used to fix the transverse momentum of the reconstructed neutrinos
123  // We can either enforce momentum conservation by disregarding the MET, ie:
124  // pT = sum of all the visible particles,
125  // Or we can fix it using the MET given as input:
126  // pT = -MET
127  // In the latter case, it is the user's job to ensure momentum conservation at
128  // the matrix element level (by using the Boost module, for instance).
129  LorentzVector pT;
130  if (pT_is_met) {
131  pT = -*m_met;
132  } else {
133  pT = *p2;
134  for (size_t i = 0; i < m_branches.size(); i++) {
135  pT += *m_branches[i];
136  }
137  }
138 
139  // p1x = alpha1 E1 + beta1 ALPHA + gamma1
140  // p1y = ...(2)
141  // p1z = ...(3)
142  // E3 = ...(4)
143  const double cosphi3 = std::cos(p3->Phi());
144  const double sinphi3 = std::sin(p3->Phi());
145  const double costhe3 = std::cos(p3->Theta());
146  const double sinthe3 = std::sin(p3->Theta());
147 
148  const double E2 = p2->E();
149  const double p2x = p2->Px();
150  const double p2y = p2->Py();
151  const double p2z = p2->Pz();
152 
153  const double pTx = pT.Px();
154  const double pTy = pT.Py();
155 
156  // Term appears regularly, compute once.
157  const double X = p2x * sinthe3 * cosphi3 + p2y * sinthe3 * sinphi3 + p2z * costhe3;
158 
159  // Denominator that appears in several of the follwing eq.
160  // No need to compute it multiple times
161  const double denom = 2. * (E2 - X);
162 
163  const double beta1 = (cosphi3 * sinthe3) / denom;
164  const double gamma1 =
165  -(2 * E2 * pTx - 2 * X * pTx - *s12 * cosphi3 * sinthe3 + *s123 * cosphi3 * sinthe3) / denom;
166 
167  const double beta2 = (sinthe3 * sinphi3) / denom;
168  const double gamma2 =
169  -(2 * E2 * pTy - 2 * X * pTy - *s12 * sinthe3 * sinphi3 + *s123 * sinthe3 * sinphi3) / denom;
170 
171  const double alpha3 = E2 / p2z;
172  const double beta3 = (p2x * cosphi3 * sinthe3 + p2y * sinthe3 * sinphi3) / (-p2z * denom);
173  const double gamma3 = 0.5 *
174  (-*s12 + SQ(m1) + p2Sq + 2 * p2x * (pTx + sinthe3 * cosphi3 * (*s123 - *s12) / denom) +
175  2 * p2y * (pTy + sinthe3 * sinphi3 * (*s123 - *s12) / denom)) /
176  p2z;
177 
178  const double beta4 = -1. / denom;
179  const double gamma4 = (*s123 - *s12) / denom;
180 
181  // a11 E1^2 + a22 ALPHA^2 + a12 E1*ALPHA + a10 E1 + a01 ALPHA + a00 = 0
182  // id. with bij
183  const double a11 = SQ(alpha3) - 1;
184  const double a22 = SQ(beta1) + SQ(beta2) + SQ(beta3);
185  const double a12 = 2. * (alpha3 * beta3);
186  const double a10 = 2. * (alpha3 * gamma3);
187  const double a01 = 2. * (beta1 * gamma1 + beta2 * gamma2 + beta3 * gamma3);
188  const double a00 = SQ(gamma1) + SQ(gamma2) + SQ(gamma3) + SQ(m1);
189 
190  const double b11 = 0;
191  const double b22 = beta4 * (-beta1 * sinthe3 * cosphi3 - beta2 * sinthe3 * sinphi3 - beta3 * costhe3);
192  const double b12 = beta4 - alpha3 * beta4 * costhe3;
193  const double b10 = gamma4 - alpha3 * gamma4 * costhe3;
194  const double b01 = -0.5 - (beta1 * gamma4 + beta4 * gamma1) * sinthe3 * cosphi3 -
195  (beta2 * gamma4 + beta4 * gamma2) * sinthe3 * sinphi3 -
196  (beta3 * gamma4 + beta4 * gamma3) * costhe3;
197  const double b00 = gamma4 * (-gamma1 * sinthe3 * cosphi3 - gamma2 * sinthe3 * sinphi3 - gamma3 * costhe3);
198 
199  // Find the intersection of the 2 conics (at most 4 real solutions for (e1,ALPHA))
200  std::vector<double> e1, ALPHA;
201  solve2Quads(a11, a22, a12, a10, a01, a00, b11, b22, b12, b10, b01, b00, e1, ALPHA, false);
202 
203  // For each solution (e1,ALPHA), find the neutrino 4-momentum p1
204  if (e1.size() == 0)
205  return Status::NEXT;
206 
207  for (unsigned int i = 0; i < e1.size(); i++) {
208  const double E1 = e1.at(i);
209  const double alp = ALPHA.at(i);
210 
211  //Make sure E1 is not negative
212  if (E1 <= 0.)
213  continue;
214 
215  const double E3 = beta4 * alp + gamma4;
216  // Make sure E3 is not negative
217  if (E3 <= 0.)
218  continue;
219 
220  const double p1x = beta1 * alp + gamma1;
221  const double p1y = beta2 * alp + gamma2;
222  const double p1z = alpha3 * E1 + beta3 * alp + gamma3;
223 
224  LorentzVector p1(p1x, p1y, p1z, E1);
225 
226  const double p3x = E3 * sinthe3 * cosphi3;
227  const double p3y = E3 * sinthe3 * sinphi3;
228  const double p3z = E3 * costhe3;
229 
230  LorentzVector p3_sol(p3x, p3y, p3z, E3);
231 
232  // Check if solutions are physical
233  LorentzVector tot = p1 + *p2 + p3_sol;
234  for (size_t i = 0; i < m_branches.size(); i++) {
235  tot += *m_branches[i];
236  }
237  const double q1Pz = std::abs(tot.Pz() + tot.E()) / 2.;
238  const double q2Pz = std::abs(tot.Pz() - tot.E()) / 2.;
239  if (q1Pz > sqrt_s / 2 || q2Pz > sqrt_s / 2)
240  continue;
241 
242  if (!ApproxComparison((p1 + p3_sol + pT).Pt(), 0.)) {
243 #ifndef NDEBUG
244  LOG(trace) << "[BlockC] Throwing solution because total Pt is incorrect. "
245  << "Expected " << 0. << ", got " << (p1 + p3_sol + pT).Pt();
246 #endif
247  continue;
248  }
249 
250  if (!ApproxComparison(p1.M() / p1.E(), m1 / p1.E())) {
251 #ifndef NDEBUG
252  LOG(trace) << "[BlockC] Throwing solution because p1 has an invalid mass. " <<
253  "Expected " << m1 << ", got " << p1.M();
254 #endif
255  continue;
256  }
257 
258  if (!ApproxComparison((p1 + *p2).M2(), *s12)) {
259 #ifndef NDEBUG
260  LOG(trace) << "[BlockC] Throwing solution because of invalid invariant mass. " <<
261  "Expected " << *s12 << ", got " << (p1 + *p2).M2();
262 #endif
263  continue;
264  }
265 
266  if (!ApproxComparison((p1 + *p2 + p3_sol).M2(), *s123)) {
267 #ifndef NDEBUG
268  LOG(trace) << "[BlockC] Throwing solution because of invalid invariant mass. " <<
269  "Expected " << *s123 << ", got " << (p1 + *p2 + p3_sol).M2();
270 #endif
271  continue;
272  }
273 
274  const double jacobian = SQ(E3) * sinthe3 / (32 * SQ(M_PI) * SQ(sqrt_s) *
275  std::abs((p3_sol.Dot(p1 + *p2) + SQ(p3x) + SQ(p3y)) * (E2 * p1z - E1 * p2z) + p3x * p3z * (E1 * p2x - E2 * p1x)
276  + p3x * E3 * (p1x * p2z - p1z * p2x) + p3y * p3z * (E1 * p2y - E2 * p1y) + E3 * p3y * (p1y * p2z - p1z * p2y)));
277 
278  Solution s {{p1, p3_sol}, jacobian, true};
279  solutions->push_back(s);
280  }
281 
282  return solutions->size() > 0 ? Status::OK : Status::NEXT;
283  }
284 
285 private:
286  double sqrt_s;
287  bool pT_is_met;
288  double m1;
289 
290  // Inputs
291  Value<double> s12;
292  Value<double> s123;
293  std::vector<Value<LorentzVector>> m_branches;
294  Value<LorentzVector> m_met;
297 
298  // Outputs
299  std::shared_ptr<SolutionCollection> solutions = produce<SolutionCollection>("solutions");
300 };
301 
302 REGISTER_MODULE(BlockC)
303  .Input("s12")
304  .Input("s123")
305  .Input("p2")
306  .Input("p3")
307  .OptionalInputs("branches")
308  .Input("met=met::p4")
309  .Output("solutions")
310  .GlobalAttr("energy:double")
311  .Attr("pT_is_met:bool=false")
312  .Attr("m1:double=0");
313 
bool ApproxComparison(double value, double expected)
Compare two doubles and return true if they are approximately equal.
Definition: Math.cc:409
Generic solution structure representing a set of particles, along with its jacobian.
Definition: Solution.h:28
Mathematical functions.
virtual Status work() override
Main function.
Definition: BlockC.cc:112
An identifier of a module's output.
Definition: InputTag_fwd.h:37
Parent class for all the modules.
Definition: Module.h:37
A class encapsulating a lua table.
Definition: ParameterSet.h:82
Final (main) Block C, describing
Definition: BlockC.cc:81
#define SQ(x)
Compute \(x^2\).
Definition: Math.h:25
Module(PoolPtr pool, const std::string &name)
Constructor.
Definition: Module.h:61
bool solve2Quads(const double a20, const double a02, const double a11, const double a10, const double a01, const double a00, const double b20, const double b02, const double b11, const double b10, const double b01, const double b00, std::vector< double > &E1, std::vector< double > &E2, bool verbose=false)
Solve a system of two quadratic equations.
Definition: Math.cc:192