lua.cc
Go to the documentation of this file.
1 /*
2  * MoMEMta: a modular implementation of the Matrix Element Method
3  * Copyright (C) 2016 Universite catholique de Louvain (UCL), Belgium
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  */
18 
25 #include <catch.hpp>
26 
27 #include <iostream>
28 #include <string>
29 #include <vector>
30 
31 #include <momemta/InputTag.h>
32 #include <momemta/ILuaCallback.h>
33 #include <momemta/Logging.h>
34 #include <momemta/ModuleFactory.h>
35 #include <momemta/ModuleRegistry.h>
36 #include <momemta/ParameterSet.h>
37 
38 #include <ExecutionPath.h>
39 #include <lua/LazyTable.h>
40 #include <lua/ParameterSetParser.h>
41 #include <lua/bindings/Path.h>
42 #include <lua/bindings/Types.h>
43 #include <lua/utils.h>
44 
45 void execute_string(std::shared_ptr<lua_State> L, const std::string& code) {
46  if (luaL_dostring(L.get(), code.c_str())) {
47  std::string error = lua_tostring(L.get(), -1);
48  FAIL(error);
49  }
50 }
51 
53  public:
54  LuaCallbackMock(): n_dimensions(0) {}
55 
56  virtual void onModuleDeclared(const std::string& type, const std::string& name) override {
57  modules.push_back({type, name});
58  }
59 
60  virtual void onIntegrandDeclared(const InputTag& tag) override {
61  integrands.push_back(tag);
62  }
63 
64  virtual void onNewPath(const ExecutionPath& path) override {
65  paths.push_back(path);
66  }
67 
68  virtual void addIntegrationDimension() override {
69  n_dimensions++;
70  }
71 
72  virtual void onNewInputDeclared(const std::string& name) override {
73  inputs.push_back(name);
74  }
75 
76  std::vector<std::pair<std::string, std::string>> modules;
77  std::vector<InputTag> integrands;
78  std::vector<ExecutionPath> paths;
79  std::size_t n_dimensions;
80  std::vector<std::string> inputs;
81 };
82 
83 // A small mock of lua::LazyTable to change the visibility of the `freeze` function
85  using lua::LazyTable::LazyTable;
86 
87  public:
88  virtual void freeze() override {
89  lua::LazyTable::freeze();
90  }
91 };
92 
93 TEST_CASE("lua parsing utilities", "[lua]") {
94 
95  // Suppress log messages
96  auto default_log_level = logging::level::warning;
97  logging::set_level(default_log_level);
98 
99  LuaCallbackMock luaCallback;
100  REQUIRE(luaCallback.modules.empty());
101  std::shared_ptr<lua_State> L = lua::init_runtime(&luaCallback);
102 
103  auto stack_size = lua_gettop(L.get());
104 
105  SECTION("custom functions") {
106 
107  logging::set_level(logging::level::fatal);
108  execute_string(L, "load_modules('not_existing.so')");
109  logging::set_level(default_log_level);
110 
111  execute_string(L, "parameter('not_existing')");
112 
113  // Check that the 'add_dimension()' function returns the correct InputTag
114  // and that the index gets correctly incremented at each call.
115  execute_string(L, "index1 = add_dimension()");
116  lua_getglobal(L.get(), "index1");
117  auto value = lua::to_any(L.get(), -1);
118  REQUIRE( (momemta::any_cast<InputTag>(value.first)).toString() == "cuba::ps_points/1");
119  execute_string(L, "index2 = add_dimension()");
120  lua_getglobal(L.get(), "index2");
121  value = lua::to_any(L.get(), -1);
122  REQUIRE( (momemta::any_cast<InputTag>(value.first)).toString() == "cuba::ps_points/2");
123  lua_pop(L.get(), 2);
124  // 'add_dimension()' has been called twice, so we should have two dimension in the configuation:
125  REQUIRE( luaCallback.n_dimensions == 2 );
126 
127  execute_string(L, "integrand('integrand1::output', 'integrand2::output')");
128  REQUIRE( luaCallback.integrands.size() == 2 );
129  REQUIRE( luaCallback.integrands.at(1).toString() == "integrand2::output" );
130  }
131 
132  SECTION("defining modules") {
133  execute_string(L, "BreitWignerGenerator.test = {}");
134  REQUIRE(luaCallback.modules.size() == 1);
135  REQUIRE(luaCallback.modules.back().first == "BreitWignerGenerator");
136  REQUIRE(luaCallback.modules.back().second == "test");
137 
138  execute_string(L, "BreitWignerGenerator.test2 = {}");
139  REQUIRE(luaCallback.modules.size() == 2);
140  REQUIRE(luaCallback.modules.back().second == "test2");
141  }
142 
143  SECTION("loading modules") {
144  momemta::ModuleList modules;
145  momemta::ModuleRegistry::get().exportList(true, modules);
146 
147  auto n_modules = modules.size();
148 
149  execute_string(L, "load_modules('libempty_module.so')");
150 
151  momemta::ModuleRegistry::get().exportList(true, modules);
152 
153  REQUIRE(modules.size() == n_modules + 1);
154  }
155 
156  SECTION("parsing values") {
157  // Integer
158  lua_pushinteger(L.get(), 42);
159  auto value = lua::to_any(L.get(), -1);
160  REQUIRE(value.first.type() == typeid(int64_t));
161  REQUIRE(momemta::any_cast<int64_t>(value.first) == 42);
162  REQUIRE_FALSE(value.second);
163  lua_pop(L.get(), 1);
164 
165  // Double
166  lua_pushnumber(L.get(), 38.5);
167  value = lua::to_any(L.get(), -1);
168  REQUIRE(value.first.type() == typeid(double));
169  REQUIRE(momemta::any_cast<double>(value.first) == Approx(38.5));
170  REQUIRE_FALSE(value.second);
171  lua_pop(L.get(), 1);
172 
173  // Boolean
174  lua_pushboolean(L.get(), true);
175  value = lua::to_any(L.get(), -1);
176  REQUIRE(value.first.type() == typeid(bool));
177  REQUIRE(momemta::any_cast<bool>(value.first) == true);
178  REQUIRE_FALSE(value.second);
179  lua_pop(L.get(), 1);
180 
181  // std::string
182  lua_pushliteral(L.get(), "lua is fun");
183  value = lua::to_any(L.get(), -1);
184  REQUIRE(value.first.type() == typeid(std::string));
185  REQUIRE(momemta::any_cast<std::string>(value.first) == "lua is fun");
186  REQUIRE_FALSE(value.second);
187  lua_pop(L.get(), 1);
188 
189  // Double vector
190  execute_string(L, "return {0.1, 0.2, 0.3}");
191  value = lua::to_any(L.get(), -1);
192  REQUIRE(value.first.type() == typeid(std::vector<double>));
193  REQUIRE_FALSE(value.second);
194  {
195  auto v = momemta::any_cast<std::vector<double>>(value.first);
196  REQUIRE(v.size() == 3);
197  REQUIRE(v[0] == Approx(0.1));
198  REQUIRE(v[1] == Approx(0.2));
199  REQUIRE(v[2] == Approx(0.3));
200  }
201  lua_pop(L.get(), 1);
202 
203  // Double vector with an integer inside
204  execute_string(L, "return {0.1, 2, 0.3}");
205  value = lua::to_any(L.get(), -1);
206  REQUIRE(value.first.type() == typeid(std::vector<double>));
207  REQUIRE_FALSE(value.second);
208  {
209  auto v = momemta::any_cast<std::vector<double>>(value.first);
210  REQUIRE(v.size() == 3);
211  REQUIRE(v[0] == Approx(0.1));
212  REQUIRE(v[1] == Approx(2));
213  REQUIRE(v[2] == Approx(0.3));
214  }
215  lua_pop(L.get(), 1);
216 
217  // Integer vector
218  execute_string(L, "return {1, 2, 3}");
219  value = lua::to_any(L.get(), -1);
220  REQUIRE(value.first.type() == typeid(std::vector<int64_t>));
221  REQUIRE_FALSE(value.second);
222  {
223  auto v = momemta::any_cast<std::vector<int64_t>>(value.first);
224  REQUIRE(v.size() == 3);
225  REQUIRE(v[0] == 1);
226  REQUIRE(v[1] == 2);
227  REQUIRE(v[2] == 3);
228  }
229  lua_pop(L.get(), 1);
230 
231  // Invalid array
232  execute_string(L, "return {1, 'string', false}");
233  REQUIRE_THROWS_AS(lua::to_any(L.get(), -1), lua::invalid_array_error);
234  lua_pop(L.get(), 1);
235  }
236 
237  SECTION("parsing lazy values") {
238 
239  // Setup global parameters table
240  execute_string(L, "parameters = { top_mass = 173. }");
241 
242  SECTION("lazy function") {
243  execute_string(L, "return parameter('top_mass')");
244  auto value = lua::to_any(L.get(), -1);
245  REQUIRE(value.second == true);
246  REQUIRE(value.first.type() == typeid(lua::LazyFunction));
247  auto fct = momemta::any_cast<lua::LazyFunction>(value.first);
248  auto fct_evaluated = fct();
249  REQUIRE(fct_evaluated.type() == typeid(double));
250  REQUIRE(momemta::any_cast<double>(fct_evaluated) == Approx(173.));
251  lua_pop(L.get(), 1);
252  }
253 
254  SECTION("lazy function after modification of parameter") {
255  execute_string(L, "return parameter('top_mass')");
256  auto value = lua::to_any(L.get(), -1);
257  REQUIRE(value.second == true);
258  REQUIRE(value.first.type() == typeid(lua::LazyFunction));
259 
260  // Edit parameter
261  lua_getglobal(L.get(), "parameters");
262  lua_pushnumber(L.get(), 175.);
263  lua_setfield(L.get(), -2, "top_mass");
264  lua_pop(L.get(), 1);
265 
266  auto fct = momemta::any_cast<lua::LazyFunction>(value.first);
267  auto fct_evaluated = fct();
268  REQUIRE(fct_evaluated.type() == typeid(double));
269  REQUIRE(momemta::any_cast<double>(fct_evaluated) == Approx(175.));
270  lua_pop(L.get(), 1);
271  }
272 
273  SECTION("lazy table field") {
274  lua::LazyTableField lazy(L.get(), "parameters", "top_mass");
275 
276  SECTION("evaluation") {
277  auto value = lazy();
278  REQUIRE(value.type() == typeid(double));
279  REQUIRE(momemta::any_cast<double>(value) == Approx(173.));
280  }
281 
282  SECTION("edition") {
283  lazy.set(175.);
284 
285  auto value = lazy();
286  REQUIRE(value.type() == typeid(double));
287  REQUIRE(momemta::any_cast<double>(value) == Approx(175.));
288  }
289  }
290  }
291 
292  SECTION("ParameterSet evaluation") {
293  auto def = R"(test_table = {
294  integer = 1,
295  float = 10.,
296  string = "test",
297  inputtag = "module::parameter",
298  vector = {0, 1, 2, 3}
299 })";
300 
301  execute_string(L, def);
302 
303  int type = lua_getglobal(L.get(), "test_table");
304  REQUIRE(type == LUA_TTABLE);
305 
306  ParameterSet p;
307  ParameterSetParser::parse(p, L.get(), -1);
308 
309  REQUIRE(p.existsAs<int64_t>("integer"));
310  REQUIRE(p.get<int64_t>("integer") == 1);
311 
312  REQUIRE(p.existsAs<double>("float"));
313  REQUIRE(p.get<double>("float") == Approx(10.));
314 
315  REQUIRE(p.existsAs<std::string>("string"));
316  REQUIRE(p.get<std::string>("string") == "test");
317 
318  auto i = InputTag("module", "parameter");
319  REQUIRE(p.existsAs<InputTag>("inputtag"));
320  REQUIRE(p.get<InputTag>("inputtag") == i);
321 
322  REQUIRE(p.existsAs<std::vector<int64_t>>("vector"));
323  auto v = p.get<std::vector<int64_t>>("vector");
324  REQUIRE(v.size() == 4);
325  REQUIRE(v[0] == 0);
326  REQUIRE(v[1] == 1);
327  REQUIRE(v[2] == 2);
328  REQUIRE(v[3] == 3);
329 
330  lua_pop(L.get(), 1);
331  }
332 
333  SECTION("LazyParameterSet evaluation") {
334  auto def = R"(test_table = {
335  integer = 1,
336  float = 10.,
337  string = "test",
338  inputtag = "module::parameter"
339 })";
340 
341  execute_string(L, def);
342 
343  int type = lua_getglobal(L.get(), "test_table");
344  REQUIRE(type == LUA_TTABLE);
345 
346  LazyTableMock p(L, "test_table");
347  ParameterSetParser::parse(p, L.get(), -1);
348 
349  auto f = p;
350  f.freeze();
351 
352  REQUIRE(f.existsAs<int64_t>("integer"));
353  REQUIRE(f.get<int64_t>("integer") == 1);
354 
355  REQUIRE(f.existsAs<double>("float"));
356  REQUIRE(f.get<double>("float") == Approx(10.));
357 
358  REQUIRE(f.existsAs<std::string>("string"));
359  REQUIRE(f.get<std::string>("string") == "test");
360 
361  auto i = InputTag("module", "parameter");
362  REQUIRE(f.existsAs<InputTag>("inputtag"));
363  REQUIRE(f.get<InputTag>("inputtag") == i);
364 
365  // Edit the parameter set, and refreeze
366 
367  // Change value
368  p.set("integer", 10);
369 
370  // Change value AND type
371  p.set("float", true);
372 
373  // Add new value
374  p.set("new", 125.);
375 
376  f = p;
377  f.freeze();
378 
379  REQUIRE(f.existsAs<int64_t>("integer"));
380  REQUIRE(f.get<int64_t>("integer") == 10);
381 
382  REQUIRE(f.existsAs<bool>("float"));
383  REQUIRE(f.get<bool>("float") == true);
384 
385  REQUIRE(f.existsAs<std::string>("string"));
386  REQUIRE(f.get<std::string>("string") == "test");
387 
388  REQUIRE(f.existsAs<InputTag>("inputtag"));
389  REQUIRE(f.get<InputTag>("inputtag") == i);
390 
391  REQUIRE(f.existsAs<double>("new"));
392  REQUIRE(f.get<double>("new") == Approx(125));
393 
394  lua_pop(L.get(), 1);
395  }
396 
397  SECTION("LazyParameterSet with non-existing table") {
398 
399  int type = lua_getglobal(L.get(), "test_table");
400  lua_pop(L.get(), 1);
401  REQUIRE(type == LUA_TNIL);
402 
403  LazyTableMock p(L, "test_table");
404 
405  // Table must not exist
406  type = lua_getglobal(L.get(), "test_table");
407  lua_pop(L.get(), 1);
408  REQUIRE(type == LUA_TNIL);
409 
410  p.set("key", "value");
411 
412  // Table must have been created
413  type = lua_getglobal(L.get(), "test_table");
414  lua_pop(L.get(), 1);
415  REQUIRE(type == LUA_TTABLE);
416 
417  p.freeze();
418 
419  REQUIRE(p.existsAs<std::string>("key"));
420  REQUIRE(p.get<std::string>("key") == "value");
421  }
422 
423  SECTION("Path") {
424  auto def = R"(path = Path("a", "b", "c"))";
425  execute_string(L, def);
426 
427  auto type = lua_getglobal(L.get(), "path");
428  REQUIRE(type == LUA_TUSERDATA);
429 
430  std::string type_name = get_custom_type_name(L.get(), -1);
431  REQUIRE(type_name == LUA_PATH_TYPE_NAME);
432 
433  ExecutionPath* path = lua::path_get(L.get(), -1);
434  REQUIRE(path != nullptr);
435 
436  REQUIRE(path->elements.size() == 3);
437  REQUIRE(path->elements[0] == "a");
438  REQUIRE(path->elements[1] == "b");
439  REQUIRE(path->elements[2] == "c");
440 
441  lua_pop(L.get(), 1);
442  }
443 
444  SECTION("Path to momemta::any") {
445  auto def = R"(path = Path("a"))";
446  execute_string(L, def);
447 
448  auto type = lua_getglobal(L.get(), "path");
449  REQUIRE(type == LUA_TUSERDATA);
450 
451  auto path = get_custom_type_ptr(L.get(), -1);
452  REQUIRE(path.type() == typeid(ExecutionPath));
453 
454  lua_pop(L.get(), 1);
455  }
456 
457  REQUIRE(stack_size == lua_gettop(L.get()));
458 }
Notification callback used for communication between the lua file and MoMEMta.
Definition: ILuaCallback.h:30
virtual void onNewPath(const ExecutionPath &path) override
A new path is declared in the configuration file.
Definition: lua.cc:64
A specialization of ParameterSet for lazy loading of lua tables.
Definition: LazyTable.h:62
std::shared_ptr< lua_State > init_runtime(ILuaCallback *callback)
Initialize the lua runtime.
Definition: utils.cc:554
virtual void addIntegrationDimension() override
A new integration dimension is requested in the configuration file.
Definition: lua.cc:68
virtual void onIntegrandDeclared(const InputTag &tag) override
The integrand was defined in the configuration file.
Definition: lua.cc:60
Lua binding of C++ Path class.
ExecutionPath * path_get(lua_State *L, int index)
Retrieve an instance of Path from the lua stack.
Definition: Path.cc:83
Lazy table field in lua (delayed table access)
Definition: LazyTable.h:35
std::enable_if< std::is_same< T, bool >::value||std::is_same< T, InputTag >::value >::type set(const std::string &name, const T &value)
Change the value of a given parameter. If the parameter does not exist, it's first created...
Definition: ParameterSet.h:160
static ModuleRegistry & get()
A singleton available at startup.
static void parse(ParameterSet &p, lua_State *L, int index)
Convert a lua table to a ParameterSet.
virtual void onModuleDeclared(const std::string &type, const std::string &name) override
A module is declared in the configuration file.
Definition: lua.cc:56
An identifier of a module's output.
Definition: InputTag_fwd.h:37
Generic functions to deal with custom lua types.
A class encapsulating a lua table.
Definition: ParameterSet.h:82
momemta::any get_custom_type_ptr(lua_State *L, int index)
Convert a lua custom table to a momemta::any value.
Definition: Types.cc:48
std::string get_custom_type_name(lua_State *L, int index)
Get the type of a custom table.
Definition: Types.cc:35
void exportList(bool ignore_internal, ModuleList &list) const
std::pair< momemta::any, bool > to_any(lua_State *L, int index)
Convert a lua type to momemta::any.
Definition: utils.cc:191
virtual void onNewInputDeclared(const std::string &name) override
The configuration file declared a new input.
Definition: lua.cc:72
Lazy function in lua (delayed function evaluation)
Definition: utils.h:72