utils.cc
1 /*
2  * MoMEMta: a modular implementation of the Matrix Element Method
3  * Copyright (C) 2016 Universite catholique de Louvain (UCL), Belgium
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  */
18 
19 #include <lua/utils.h>
20 
21 #include <cassert>
22 #include <regex>
23 
24 #include <momemta/InputTag.h>
25 #include <momemta/ILuaCallback.h>
26 #include <momemta/Logging.h>
27 #include <momemta/ModuleRegistry.h>
28 #include <momemta/ParameterSet.h>
29 #include <momemta/Utils.h>
30 
31 #include <LibraryManager.h>
32 #include <lua/ParameterSetParser.h>
33 #include <lua/bindings/Path.h>
34 #include <lua/bindings/Types.h>
35 
36 // Defined by `embedLua.py` at build-time
37 extern void execute_embed_lua_code(lua_State*);
38 
39 namespace lua {
40 
41  Lazy::Lazy(lua_State* L) {
42  this->L = L;
43  }
44 
45  LazyFunction::LazyFunction(lua_State* L, int index): Lazy(L) {
46  auto absolute_index = get_index(L, index);
47 
48  // Duplicate the function on the top of the stack. This ensure the stack size won't change
49  lua_pushvalue(L, absolute_index);
50 
51  // Pop the anonymous function from the stack, and store it in the global lua registry
52  ref_index = luaL_ref(L, LUA_REGISTRYINDEX);
53  }
54 
56 
57  LOG(trace) << "[LazyFunction::operator()] >> stack size = " << lua_gettop(L);
58 
59  // Pop the anonymous function from the registry, and push it on the top of the stack
60  lua_rawgeti(L, LUA_REGISTRYINDEX, ref_index);
61 
62  // Call the function. The function removed from the stack, and the return value is pushed on the top of the stack
63  auto result = lua_pcall(L, 0, 1, 0);
64  if (result != LUA_OK) {
65  std::string error = lua_tostring(L, -1);
66  LOG(fatal) << "Fail to call lua anonymous function. Return value is " << result << ". Error message: " << error;
67  }
68 
69  momemta::any value;
70  bool lazy = false;
71  std::tie(value, lazy) = to_any(L, -1);
72  assert(!lazy);
73 
74  lua_pop(L, 1);
75 
76  LOG(trace) << "[LazyFunction::operator()] << stack size = " << lua_gettop(L);
77 
78  return value;
79  }
80 
81  Type type(lua_State* L, int index) {
82  int t = lua_type(L, index);
83 
84  switch (t) {
85  case LUA_TBOOLEAN:
86  return BOOLEAN;
87  break;
88 
89  case LUA_TSTRING: {
90  std::string value = lua_tostring(L, index);
91  if (InputTag::isInputTag(value))
92  return INPUT_TAG;
93  else
94  return STRING;
95  } break;
96 
97  case LUA_TNUMBER:
98  if (lua_isinteger(L, index))
99  return INTEGER;
100  else
101  return REAL;
102  break;
103 
104  case LUA_TTABLE: {
105  // We only support ParameterSet for table
106  if (lua_is_array(L, index) == -1) {
107  return PARAMETER_SET;
108  }
109 
110  } break;
111  }
112 
113  return NOT_SUPPORTED;
114  }
115 
116  size_t get_index(lua_State* L, int index) {
117  return (index < 0) ? lua_gettop(L) + index + 1 : index;
118  }
119 
125  int lua_is_array(lua_State* L, int index) {
126 
127  LOG(trace) << "[lua_is_array] >> stack size = " << lua_gettop(L);
128 
129  size_t table_index = get_index(L, index);
130 
131  if (lua_type(L, table_index) != LUA_TTABLE)
132  return -1;
133 
134  size_t size = 0;
135 
136  lua_pushnil(L);
137  while (lua_next(L, table_index) != 0) {
138  if (lua_type(L, -2) != LUA_TNUMBER) {
139  lua_pop(L, 2);
140  LOG(trace) << "[lua_is_array] << stack size = " << lua_gettop(L);
141  return -1;
142  }
143 
144  size++;
145 
146  lua_pop(L, 1);
147  }
148 
149  LOG(trace) << "[lua_is_array] << stack size = " << lua_gettop(L);
150  return size;
151  }
152 
158  Type lua_array_unique_type(lua_State* L, int index) {
159 
160  size_t absolute_index = get_index(L, index);
161 
162  if (lua_type(L, absolute_index) != LUA_TTABLE)
163  return NOT_SUPPORTED;
164 
165  Type result = NOT_SUPPORTED;
166 
167  lua_pushnil(L);
168  while (lua_next(L, absolute_index) != 0) {
169  if (result == NOT_SUPPORTED) {
170  result = type(L, -1);
171  } else {
172  Type entry_type = type(L, -1);
173 
174  if ((result == INTEGER) && (entry_type == REAL))
175  result = REAL;
176  else if ((result == REAL) && (entry_type == INTEGER))
177  result = REAL;
178  else if (result != entry_type) {
179  lua_pop(L, 2);
180  result = NOT_SUPPORTED;
181  break;
182  }
183  }
184 
185  lua_pop(L, 1);
186  }
187 
188  return result;
189  }
190 
191  std::pair<momemta::any, bool> to_any(lua_State* L, int index) {
192 
193  LOG(trace) << "[to_any] >> stack size = " << lua_gettop(L);
194  size_t absolute_index = get_index(L, index);
195 
196  momemta::any result;
197  bool lazy = false;
198 
199  auto type = lua_type(L, absolute_index);
200  switch (type) {
201  case LUA_TNUMBER: {
202  if (lua_isinteger(L, absolute_index)) {
203  int64_t number = lua_tointeger(L, absolute_index);
204  result = number;
205  } else {
206  double number = lua_tonumber(L, absolute_index);
207  result = number;
208  }
209  } break;
210 
211  case LUA_TBOOLEAN: {
212  bool value = lua_toboolean(L, absolute_index);
213  result = value;
214  } break;
215 
216  case LUA_TSTRING: {
217  std::string value = lua_tostring(L, absolute_index);
218  if (InputTag::isInputTag(value)) {
219  InputTag tag = InputTag::fromString(value);
220  result = tag;
221  } else {
222  result = value;
223  }
224  } break;
225 
226  case LUA_TTABLE: {
227  LOG(trace) << "[to_any::table] >> stack size = " << lua_gettop(L);
228  if (lua::lua_is_array(L, absolute_index) > 0) {
229 
231 
232  if ((type = lua::lua_array_unique_type(L, absolute_index)) != NOT_SUPPORTED) {
233  result = to_vector(L, absolute_index, type);
234  } else {
235  throw invalid_array_error("Various types stored into the array. This is not supported.");
236  }
237 
238  } else {
239  ParameterSet cfg;
240  ParameterSetParser::parse(cfg, L, absolute_index);
241  result = cfg;
242  }
243  LOG(trace) << "[to_any::table] << stack size = " << lua_gettop(L);
244  } break;
245 
246  case LUA_TFUNCTION: {
247  LOG(trace) << "[to_any::function] >> stack size = " << lua_gettop(L);
248 
249  result = LazyFunction(L, absolute_index);
250  lazy = true;
251 
252  LOG(trace) << "[to_any::function] << stack size = " << lua_gettop(L);
253  } break;
254 
255  case LUA_TUSERDATA: {
256  result = get_custom_type_ptr(L, absolute_index);
257 
258  } break;
259 
260  default: {
261  LOG(fatal) << "Unsupported lua type: " << lua_type(L, absolute_index);
263  } break;
264  }
265 
266  LOG(trace) << "[to_any] << final type = " << demangle(result.type().name());
267  LOG(trace) << "[to_any] << stack size = " << lua_gettop(L);
268  return {result, lazy};
269  }
270 
271  void push_any(lua_State* L, const momemta::any& value) {
272  LOG(trace) << "[push_any] >> stack size = " << lua_gettop(L);
273 
274  if (value.type() == typeid(int64_t)) {
275  int64_t v = momemta::any_cast<int64_t>(value);
276  lua_pushinteger(L, v);
277  } else if (value.type() == typeid(double)) {
278  double v = momemta::any_cast<double>(value);
279  lua_pushnumber(L, v);
280  } else if (value.type() == typeid(bool)) {
281  bool v = momemta::any_cast<bool>(value);
282  lua_pushboolean(L, v);
283  } else if (value.type() == typeid(std::string)) {
284  auto v = momemta::any_cast<std::string>(value);
285  lua_pushstring(L, v.c_str());
286  } else if (value.type() == typeid(InputTag)) {
287  auto v = momemta::any_cast<InputTag>(value).toString();
288  lua_pushstring(L, v.c_str());
289  } else {
290  LOG(fatal) << "Unsupported C++ value: " << demangle(value.type().name());
291  throw lua::unsupported_type_error(demangle(value.type().name()));
292  }
293 
294  LOG(trace) << "[push_any] << stack size = " << lua_gettop(L);
295  }
296 
297  momemta::any to_vector(lua_State* L, int index, Type t) {
298  switch (t) {
299  case BOOLEAN:
300  return to_vectorT<bool>(L, index);
301 
302  case STRING:
303  return to_vectorT<std::string>(L, index);
304 
305  case INTEGER:
306  return to_vectorT<int64_t>(L, index);
307 
308  case REAL:
309  return to_vectorT<double>(L, index);
310 
311  case INPUT_TAG:
312  return to_vectorT<InputTag>(L, index);
313 
314  case PARAMETER_SET:
315  return to_vectorT<ParameterSet>(L, index);
316 
317  case NOT_SUPPORTED:
318  break;
319  }
320 
321  throw invalid_array_error("Unsupported array type");
322  }
323 
325  template<> double special_any_cast(const momemta::any& value) {
326  if (value.type() == typeid(int64_t))
327  return static_cast<double>(momemta::any_cast<int64_t>(value));
328 
329  return momemta::any_cast<double>(value);
330  }
331 
332  int module_table_newindex(lua_State* L) {
333  lua_getmetatable(L, 1);
334  lua_getfield(L, -1, "__type");
335 
336  const char* module_type = luaL_checkstring(L, -1);
337  const char* module_name = luaL_checkstring(L, 2);
338 
339  // Remove field name from stack
340  lua_pop(L, 1);
341 
342  // Validate module name
343  // Format is: [a-zA-Z][a-zA-Z0-9_]*
344  static std::regex name_regex("[a-zA-Z][a-zA-Z0-9_]*");
345 
346  if (! std::regex_match(module_name, name_regex)) {
347  luaL_error(L, "invalid module name '%s': valid format is [a-zA-Z][a-zA-Z0-9_]*", module_name);
348  }
349 
350  lua_getfield(L, -1, "__ptr");
351  void* cfg_ptr = lua_touserdata(L, -1);
352  ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
353 
354  callback->onModuleDeclared(module_type, module_name);
355 
356  // Remove metatable from the stack
357  lua_pop(L, 2);
358 
359  // Add "@name" and "@type" fields to the module's parameters
360 
361  // Push the key and then the value
362  lua_pushstring(L, "@name");
363  lua_pushstring(L, module_name);
364  lua_rawset(L, -3);
365 
366  lua_pushstring(L, "@type");
367  lua_pushstring(L, module_type);
368  lua_rawset(L, -3);
369 
370  // And actually set the value to the table
371  lua_rawset(L, 1);
372 
373  return 0;
374  }
375 
376  void register_modules(lua_State* L, void* ptr) {
377  momemta::ModuleList modules;
378  momemta::ModuleRegistry::get().exportList(true, modules);
379 
380  for (const auto& module: modules) {
381  const char* module_name = module.name.c_str();
382 
383  int type = lua_getglobal(L, module_name);
384  lua_pop(L, 1);
385  if (type != LUA_TNIL) {
386  // Global already exists
387  continue;
388  }
389 
390  // Create a new empty table
391  lua_newtable(L);
392 
393  std::string module_metatable = module.name + "_mt";
394 
395  // Create the associated metatable
396  luaL_newmetatable(L, module_metatable.c_str());
397 
398  lua_pushstring(L, module_name);
399  lua_setfield(L, -2, "__type");
400 
401  lua_pushlightuserdata(L, ptr);
402  lua_setfield(L, -2, "__ptr");
403 
404  // Set the metatable '__newindex' function
405  // This function is called when a assignment is made to the table
406  // In our case, it's called when a new module is declared
407  const luaL_Reg l[] = {
408  {"__newindex", lua::module_table_newindex},
409  {nullptr, nullptr}
410  };
411  luaL_setfuncs(L, l, 0);
412 
413  lua_setmetatable(L, -2);
414 
415  // And register it as a global variable
416  lua_setglobal(L, module_name);
417 
418  LOG(trace) << "Registered new lua global variable '" << module_name << "'";
419  }
420  }
421 
422  int load_modules(lua_State* L) {
423  int n = lua_gettop(L);
424  if (n != 1) {
425  luaL_error(L, "invalid number of arguments: 1 expected, got %d", n);
426  }
427 
428  void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(1));
429 
430  const char *path = luaL_checkstring(L, 1);
431  LibraryManager::get().registerLibrary(path);
432 
433  register_modules(L, cfg_ptr);
434 
435  return 0;
436  }
437 
438  int parameter(lua_State* L) {
439  int n = lua_gettop(L);
440  if (n != 1) {
441  luaL_error(L, "invalid number of arguments: 1 expected, got %d", n);
442  }
443 
444  std::string parameter_name = luaL_checkstring(L, 1);
445 
446  // Create an anonymous function return the value of the parameter
447  // Assumes there's a global table named `parameters`
448 
449  std::string code = "return function() return parameters['" + parameter_name + "'] end";
450  luaL_dostring(L, code.c_str());
451 
452  return 1;
453  }
454 
455  int set_final_module(lua_State* L) {
456  int n = lua_gettop(L);
457  if (n == 0) {
458  luaL_error(L, "invalid number of arguments: at least one expected, got 0");
459  }
460 
461  void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(1));
462  ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
463 
464  for(size_t i = 1; i <= size_t(n); i++) {
465  std::string input_tag = luaL_checkstring(L, i);
466  if (!InputTag::isInputTag(input_tag)) {
467  luaL_error(L, "'%s' is not a valid InputTag", input_tag.c_str());
468  }
469  callback->onIntegrandDeclared(InputTag::fromString(input_tag));
470  }
471 
472  return 0;
473  }
474 
475  int add_integration_dimension(lua_State* L) {
476  int n = lua_gettop(L);
477  if (n != 0) {
478  luaL_error(L, "invalid number of arguments: 0 expected, got %d", n);
479  }
480 
481  // Create input tag using current value of the index
482  int64_t cuba_index = lua_tonumber(L, lua_upvalueindex(1));
483  lua_pushnumber(L, cuba_index + 1);
484  lua_replace(L, lua_upvalueindex(1));
485 
486  std::string index_tag = "cuba::ps_points/";
487  index_tag += std::to_string(cuba_index);
488 
489  // Input tag is return value of the function
490  push_any(L, index_tag);
491 
492  // Add an integration dimension in the configuration
493  void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(2));
494  ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
495  callback->addIntegrationDimension();
496 
497  return 1;
498  }
499 
508  int declare_input(lua_State* L) {
509  int n = lua_gettop(L);
510  if (n != 1) {
511  luaL_error(L, "invalid number of arguments: 1 expected, got %d", n);
512  }
513 
514  std::string input_name = luaL_checkstring(L, 1);
515 
516  void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(1));
517  ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
518 
519  callback->onNewInputDeclared(input_name);
520 
521  return 0;
522  }
523 
524  void setup_hooks(lua_State* L, void* ptr) {
525  lua_pushlightuserdata(L, ptr);
526  lua_pushcclosure(L, load_modules, 1);
527  lua_setglobal(L, "load_modules");
528 
529  lua_pushlightuserdata(L, ptr);
530  lua_pushcclosure(L, parameter, 1);
531  lua_setglobal(L, "parameter");
532 
533  // Define the `add_dimension()` function in Lua and make it available in the global namespace.
534  // See add_integration_dimension for more information.
535  lua_pushnumber(L, 1);
536  lua_pushlightuserdata(L, ptr);
537  lua_pushcclosure(L, add_integration_dimension, 2);
538  lua_setglobal(L, "add_dimension");
539 
540  // integrand() function
541  lua_pushlightuserdata(L, ptr);
542  lua_pushcclosure(L, set_final_module, 1);
543  lua_setglobal(L, "integrand");
544 
545  // momemta_declare_input function
546  lua_pushlightuserdata(L, ptr);
547  lua_pushcclosure(L, declare_input, 1);
548  lua_setglobal(L, "momemta_declare_input");
549 
550  // C++ -> lua bindings of some classes
551  path_register(L, ptr);
552  }
553 
554  std::shared_ptr<lua_State> init_runtime(ILuaCallback* callback) {
555 
556  std::shared_ptr<lua_State> L(luaL_newstate(), lua_close);
557  luaL_openlibs(L.get());
558 
559  // Register hooks function, like `load_modules`
560  lua::setup_hooks(L.get(), callback);
561 
562  // Register existing modules
563  lua::register_modules(L.get(), callback);
564 
565  // Default functions
566  // Function defined by `embedLua.py` at build-time
567  execute_embed_lua_code(L.get());
568 
569  return L;
570  }
571 
572  void inject_parameters(lua_State* L, const ParameterSet& parameters) {
573  for (const auto& parameter: parameters.getNames()) {
574  LOG(debug) << "Injecting parameter " << parameter;
575  lua::push_any(L, parameters.rawGet(parameter));
576  lua_setglobal(L, parameter.c_str());
577  }
578  }
579 
580  namespace debug {
581  std::vector<std::string> dump_stack(lua_State *L) {
582  std::vector<std::string> stack;
583  for (int i = 1; i < lua_gettop(L) + 1; i++) {
584  if (lua_isnumber(L, i)) {
585  stack.push_back("number : " + std::to_string(lua_tonumber(L, i)));
586  } else if (lua_isstring(L, i)) {
587  stack.push_back(std::string("string : ") + std::string(luaL_checkstring(L, i)));
588  } else if (lua_istable(L, i)) {
589  stack.push_back("table");
590  } else if (lua_iscfunction(L, i)) {
591  stack.push_back("cfunction");
592  } else if (lua_isfunction(L, i)) {
593  stack.push_back("function");
594  } else if (lua_isboolean(L, i)) {
595  if (lua_toboolean(L, i) != 0)
596  stack.push_back("boolean: true");
597  else
598  stack.push_back("boolean: false");
599  } else if (lua_isuserdata(L, i)) {
600  stack.push_back("userdata");
601  } else if (lua_isnil(L, i)) {
602  stack.push_back("nil");
603  } else if (lua_islightuserdata(L, i)) {
604  stack.push_back("lightuserdata");
605  }
606  }
607 
608  return stack;
609  }
610 
611  void print_stack(lua_State* L) {
612  auto stack = dump_stack(L);
613  size_t index = 0;
614  LOG(debug) << "Stack has " << stack.size() << " elements: ";
615  for (const auto& e: stack) {
616  LOG(debug) << " #" << index++ << ": " << e;
617  }
618  }
619  }
620 }
virtual void onModuleDeclared(const std::string &type, const std::string &name)=0
A module is declared in the configuration file.
Notification callback used for communication between the lua file and MoMEMta.
Definition: ILuaCallback.h:30
LazyFunction(lua_State *L, int index)
Bind an anonymous lua function.
Definition: utils.cc:45
size_t get_index(lua_State *L, int index)
Convert a negative lua stack index to an absolute index.
Definition: utils.cc:116
void setup_hooks(lua_State *L, void *ptr)
Register all C functions in the lua userspace.
Definition: utils.cc:524
std::shared_ptr< lua_State > init_runtime(ILuaCallback *callback)
Initialize the lua runtime.
Definition: utils.cc:554
Lua binding of C++ Path class.
Type lua_array_unique_type(lua_State *L, int index)
Check if a lua table contains only values of the same type.
Definition: utils.cc:158
Type type(lua_State *L, int index)
Extract the type of a lua value.
Definition: utils.cc:81
int load_modules(lua_State *L)
Hook for the load_modules lua function. The stack must have one element:
Definition: utils.cc:422
int declare_input(lua_State *L)
The configuration file declared a new input.
Definition: utils.cc:508
int lua_is_array(lua_State *L, int index)
Check if a lua table is an array.
Definition: utils.cc:125
static ModuleRegistry & get()
A singleton available at startup.
static void parse(ParameterSet &p, lua_State *L, int index)
Convert a lua table to a ParameterSet.
Type
List of all supported lua types.
Definition: utils.h:39
int ref_index
The reference index where the anonymous function is stored.
Definition: utils.h:73
lua_State * L
The global lua state. This state must be valid for as long as this instance.
Definition: utils.h:55
An identifier of a module's output.
Definition: InputTag_fwd.h:37
momemta::any to_vector(lua_State *L, int index, Type type)
Convert a lua array to a typed vector, encapsulated into a momemta::any.
Definition: utils.cc:297
void register_modules(lua_State *L, void *ptr)
Register modules in lua userspace.
Definition: utils.cc:376
void push_any(lua_State *L, const momemta::any &value)
Convert a momemta::any to a lua type, and push it to the top of the stack.
Definition: utils.cc:271
Generic functions to deal with custom lua types.
Lazy value in lua (delayed evaluation)
Definition: utils.h:54
A class encapsulating a lua table.
Definition: ParameterSet.h:82
int module_table_newindex(lua_State *L)
Hook for the metatable __newindex of the module's table.
Definition: utils.cc:332
virtual momemta::any operator()() const override
Evaluate the anonymous function.
Definition: utils.cc:55
void path_register(lua_State *L, void *ptr)
Register Path into lua runtime.
Definition: Path.cc:28
virtual void addIntegrationDimension()=0
A new integration dimension is requested in the configuration file.
Thrown if the configuration file is not valid.
Definition: utils.h:24
virtual void onNewInputDeclared(const std::string &name)=0
The configuration file declared a new input.
momemta::any get_custom_type_ptr(lua_State *L, int index)
Convert a lua custom table to a momemta::any value.
Definition: Types.cc:48
virtual void onIntegrandDeclared(const InputTag &tag)=0
The integrand was defined in the configuration file.
void exportList(bool ignore_internal, ModuleList &list) const
static InputTag fromString(const std::string &tag)
Create an input tag from its string representation.
Definition: InputTag.cc:75
int parameter(lua_State *L)
Hook for the parameter lua function. This function accepts one argument:
Definition: utils.cc:438
std::pair< momemta::any, bool > to_any(lua_State *L, int index)
Convert a lua type to momemta::any.
Definition: utils.cc:191
void inject_parameters(lua_State *L, const ParameterSet &parameters)
Inject parameters into the current lua state.
Definition: utils.cc:572
static bool isInputTag(const std::string &tag)
Check if a given string represent an InputTag.
Definition: InputTag.cc:54
const momemta::any & rawGet(const std::string &name) const
Definition: ParameterSet.cc:31
Utility functions related to lua configuration file parsing.
Definition: Path.h:31