Loading [MathJax]/extensions/tex2jax.js

Search Results

 /*
  *  MoMEMta: a modular implementation of the Matrix Element Method
  *  Copyright (C) 2016  Universite catholique de Louvain (UCL), Belgium
  *
  *  This program is free software: you can redistribute it and/or modify
  *  it under the terms of the GNU General Public License as published by
  *  the Free Software Foundation, either version 3 of the License, or
  *  (at your option) any later version.
  *
  *  This program is distributed in the hope that it will be useful,
  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  *  GNU General Public License for more details.
  *
  *  You should have received a copy of the GNU General Public License
  *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
 #include <lua/utils.h>
 
 #include <cassert>
 #include <regex>
 
 #include <momemta/InputTag.h>
 #include <momemta/ILuaCallback.h>
 #include <momemta/Logging.h>
 #include <momemta/ModuleRegistry.h>
 #include <momemta/ParameterSet.h>
 #include <momemta/Utils.h>
 
 #include <LibraryManager.h>
 #include <lua/ParameterSetParser.h>
 #include <lua/bindings/Path.h>
 #include <lua/bindings/Types.h>
 
 // Defined by `embedLua.py` at build-time
 extern void execute_embed_lua_code(lua_State*);
 
 namespace lua {
 
     /**
      * Base constructor of lazily-evaluated values (LazyFunction derives from it).
      *
      * Only stores the Lua state that will be used when the value is eventually evaluated.
      * NOTE(review): the pointer is not owned here — presumably the lua_State must outlive
      * this object; confirm against the owner of the state (see init_runtime).
      */
     Lazy::Lazy(lua_State* L) {
         this->L = L;
     }
 
     LazyFunction::LazyFunction(lua_State* L, int index): Lazy(L) {
         auto absolute_index = get_index(L, index);
 
         // Duplicate the function on the top of the stack. This ensure the stack size won't change
         lua_pushvalue(L, absolute_index);
 
         // Pop the anonymous function from the stack, and store it in the global lua registry
         ref_index = luaL_ref(L, LUA_REGISTRYINDEX);
     }
 
     momemta::any LazyFunction::operator() () const {
 
         LOG(trace) << "[LazyFunction::operator()] >> stack size = " << lua_gettop(L);
 
         // Pop the anonymous function from the registry, and push it on the top of the stack
         lua_rawgeti(L, LUA_REGISTRYINDEX, ref_index);
 
         // Call the function. The function removed from the stack, and the return value is pushed on the top of the stack
         auto result = lua_pcall(L, 0, 1, 0);
         if (result != LUA_OK) {
             std::string error = lua_tostring(L, -1);
             LOG(fatal) << "Fail to call lua anonymous function. Return value is " << result << ". Error message: " << error;
         }
 
         momemta::any value;
         bool lazy = false;
         std::tie(value, lazy) = to_any(L, -1);
         assert(!lazy);
 
         lua_pop(L, 1);
 
         LOG(trace) << "[LazyFunction::operator()] << stack size = " << lua_gettop(L);
 
         return value;
     }
 
     Type type(lua_State* L, int index) {
         int t = lua_type(L, index);
 
         switch (t) {
             case LUA_TBOOLEAN:
                 return BOOLEAN;
                 break;
 
             case LUA_TSTRING: {
                 std::string value = lua_tostring(L, index);
                 if (InputTag::isInputTag(value))
                     return INPUT_TAG;
                 else
                     return STRING;
             } break;
 
             case LUA_TNUMBER:
                 if (lua_isinteger(L, index))
                     return INTEGER;
                 else
                     return REAL;
                 break;
 
             case LUA_TTABLE: {
                 // We only support ParameterSet for table
                 if (lua_is_array(L, index) == -1) {
                     return PARAMETER_SET;
                 }
 
             } break;
         }
 
         return NOT_SUPPORTED;
     }
 
     size_t get_index(lua_State* L, int index) {
         return (index < 0) ? lua_gettop(L) + index + 1 : index;
     }
 
     int lua_is_array(lua_State* L, int index) {
 
         LOG(trace) << "[lua_is_array] >> stack size = " << lua_gettop(L);
 
         size_t table_index = get_index(L, index);
 
         if (lua_type(L, table_index) != LUA_TTABLE)
             return -1;
 
         size_t size = 0;
 
         lua_pushnil(L);
         while (lua_next(L, table_index) != 0) {
             if (lua_type(L, -2) != LUA_TNUMBER) {
                 lua_pop(L, 2);
                 LOG(trace) << "[lua_is_array] << stack size = " << lua_gettop(L);
                 return -1;
             }
 
             size++;
 
             lua_pop(L, 1);
         }
 
         LOG(trace) << "[lua_is_array] << stack size = " << lua_gettop(L);
         return size;
     }
 
     Type lua_array_unique_type(lua_State* L, int index) {
 
         size_t absolute_index = get_index(L, index);
 
         if (lua_type(L, absolute_index) != LUA_TTABLE)
             return NOT_SUPPORTED;
 
         Type result = NOT_SUPPORTED;
 
         lua_pushnil(L);
         while (lua_next(L, absolute_index) != 0) {
             if (result == NOT_SUPPORTED) {
                 result = type(L, -1);
             } else {
                 Type entry_type = type(L, -1);
 
                 if ((result == INTEGER) && (entry_type == REAL))
                     result = REAL;
                 else if ((result == REAL) && (entry_type == INTEGER))
                     result = REAL;
                 else if (result != entry_type) {
                     lua_pop(L, 2);
                     result = NOT_SUPPORTED;
                     break;
                 }
             }
 
             lua_pop(L, 1);
         }
 
         return result;
     }
 
     std::pair<momemta::any, bool> to_any(lua_State* L, int index) {
 
         LOG(trace) << "[to_any] >> stack size = " << lua_gettop(L);
         size_t absolute_index = get_index(L, index);
 
         momemta::any result;
         bool lazy = false;
 
         auto type = lua_type(L, absolute_index);
         switch (type) {
             case LUA_TNUMBER: {
                 if (lua_isinteger(L, absolute_index)) {
                     int64_t number = lua_tointeger(L, absolute_index);
                     result = number;
                 } else {
                     double number = lua_tonumber(L, absolute_index);
                     result = number;
                 }
             } break;
 
             case LUA_TBOOLEAN: {
                 bool value = lua_toboolean(L, absolute_index);
                 result = value;
             } break;
 
             case LUA_TSTRING: {
                 std::string value = lua_tostring(L, absolute_index);
                 if (InputTag::isInputTag(value)) {
                     InputTag tag = InputTag::fromString(value);
                     result = tag;
                 } else {
                     result = value;
                 }
             } break;
 
             case LUA_TTABLE: {
                 LOG(trace) << "[to_any::table] >> stack size = " << lua_gettop(L);
                 if (lua::lua_is_array(L, absolute_index) > 0) {
 
                     Type type = NOT_SUPPORTED;
 
                     if ((type = lua::lua_array_unique_type(L, absolute_index)) != NOT_SUPPORTED) {
                         result = to_vector(L, absolute_index, type);
                     } else {
                         throw invalid_array_error("Various types stored into the array. This is not supported.");
                     }
 
                 } else {
                     ParameterSet cfg;
                     ParameterSetParser::parse(cfg, L, absolute_index);
                     result = cfg;
                 }
                 LOG(trace) << "[to_any::table] << stack size = " << lua_gettop(L);
             } break;
 
             case LUA_TFUNCTION: {
                 LOG(trace) << "[to_any::function] >> stack size = " << lua_gettop(L);
 
                 result = LazyFunction(L, absolute_index);
                 lazy = true;
 
                 LOG(trace) << "[to_any::function] << stack size = " << lua_gettop(L);
             } break;
 
             case LUA_TUSERDATA: {
                 result = get_custom_type_ptr(L, absolute_index);
 
             } break;
 
             default: {
                 LOG(fatal) << "Unsupported lua type: " << lua_type(L, absolute_index);
                 throw lua::invalid_configuration_file("");
             } break;
         }
 
         LOG(trace) << "[to_any] << final type = " << demangle(result.type().name());
         LOG(trace) << "[to_any] << stack size = " << lua_gettop(L);
         return {result, lazy};
     }
 
     void push_any(lua_State* L, const momemta::any& value) {
         LOG(trace) << "[push_any] >> stack size = " << lua_gettop(L);
 
         if (value.type() == typeid(int64_t)) {
             int64_t v = momemta::any_cast<int64_t>(value);
             lua_pushinteger(L, v);
         } else if (value.type() == typeid(double)) {
             double v = momemta::any_cast<double>(value);
             lua_pushnumber(L, v);
         } else if (value.type() == typeid(bool)) {
             bool v = momemta::any_cast<bool>(value);
             lua_pushboolean(L, v);
         } else if (value.type() == typeid(std::string)) {
             auto v = momemta::any_cast<std::string>(value);
             lua_pushstring(L, v.c_str());
         } else if (value.type() == typeid(InputTag)) {
             auto v = momemta::any_cast<InputTag>(value).toString();
             lua_pushstring(L, v.c_str());
         } else {
             LOG(fatal) << "Unsupported C++ value: " << demangle(value.type().name());
             throw lua::unsupported_type_error(demangle(value.type().name()));
         }
 
         LOG(trace) << "[push_any] << stack size = " << lua_gettop(L);
     }
 
     momemta::any to_vector(lua_State* L, int index, Type t) {
         switch (t) {
             case BOOLEAN:
                 return to_vectorT<bool>(L, index);
 
             case STRING:
                 return to_vectorT<std::string>(L, index);
 
             case INTEGER:
                 return to_vectorT<int64_t>(L, index);
 
             case REAL:
                 return to_vectorT<double>(L, index);
 
             case INPUT_TAG:
                 return to_vectorT<InputTag>(L, index);
 
             case PARAMETER_SET:
                 return to_vectorT<ParameterSet>(L, index);
 
             case NOT_SUPPORTED:
                 break;
         }
 
         throw invalid_array_error("Unsupported array type");
     }
 
     template<> double special_any_cast(const momemta::any& value) {
         if (value.type() == typeid(int64_t))
             return static_cast<double>(momemta::any_cast<int64_t>(value));
 
         return momemta::any_cast<double>(value);
     }
 
     int module_table_newindex(lua_State* L) {
         lua_getmetatable(L, 1);
         lua_getfield(L, -1, "__type");
 
         const char* module_type = luaL_checkstring(L, -1);
         const char* module_name = luaL_checkstring(L, 2);
 
         // Remove field name from stack
         lua_pop(L, 1);
 
         // Validate module name
         // Format is: [a-zA-Z][a-zA-Z0-9_]*
         static std::regex name_regex("[a-zA-Z][a-zA-Z0-9_]*");
 
         if (! std::regex_match(module_name, name_regex)) {
             luaL_error(L, "invalid module name '%s': valid format is [a-zA-Z][a-zA-Z0-9_]*", module_name);
         }
 
         lua_getfield(L, -1, "__ptr");
         void* cfg_ptr = lua_touserdata(L, -1);
         ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
 
         callback->onModuleDeclared(module_type, module_name);
 
         // Remove metatable from the stack
         lua_pop(L, 2);
 
         // Add "@name" and "@type" fields to the module's parameters
 
         // Push the key and then the value
         lua_pushstring(L, "@name");
         lua_pushstring(L, module_name);
         lua_rawset(L, -3);
 
         lua_pushstring(L, "@type");
         lua_pushstring(L, module_type);
         lua_rawset(L, -3);
 
         // And actually set the value to the table
         lua_rawset(L, 1);
 
         return 0;
     }
 
     void register_modules(lua_State* L, void* ptr) {
         momemta::ModuleList modules;
         momemta::ModuleRegistry::get().exportList(true, modules);
 
         for (const auto& module: modules) {
             const char* module_name = module.name.c_str();
 
             int type = lua_getglobal(L, module_name);
             lua_pop(L, 1);
             if (type != LUA_TNIL) {
                 // Global already exists
                 continue;
             }
 
             // Create a new empty table
             lua_newtable(L);
 
             std::string module_metatable = module.name + "_mt";
 
             // Create the associated metatable
             luaL_newmetatable(L, module_metatable.c_str());
 
             lua_pushstring(L, module_name);
             lua_setfield(L, -2, "__type");
 
             lua_pushlightuserdata(L, ptr);
             lua_setfield(L, -2, "__ptr");
 
             // Set the metatable '__newindex' function
             // This function is called when a assignment is made to the table
             // In our case, it's called when a new module is declared
             const luaL_Reg l[] = {
                 {"__newindex", lua::module_table_newindex},
                 {nullptr, nullptr}
             };
             luaL_setfuncs(L, l, 0);
 
             lua_setmetatable(L, -2);
 
             // And register it as a global variable
             lua_setglobal(L, module_name);
 
             LOG(trace) << "Registered new lua global variable '" << module_name << "'";
         }
     }
 
     int load_modules(lua_State* L) {
         int n = lua_gettop(L);
         if (n != 1) {
             luaL_error(L, "invalid number of arguments: 1 expected, got %d", n);
         }
 
         void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(1));
 
         const char *path = luaL_checkstring(L, 1);
         LibraryManager::get().registerLibrary(path);
 
         register_modules(L, cfg_ptr);
 
         return 0;
     }
 
     int parameter(lua_State* L) {
         int n = lua_gettop(L);
         if (n != 1) {
             luaL_error(L, "invalid number of arguments: 1 expected, got %d", n);
         }
 
         std::string parameter_name = luaL_checkstring(L, 1);
 
         // Create an anonymous function return the value of the parameter
         // Assumes there's a global table named `parameters`
 
         std::string code = "return function() return parameters['" + parameter_name + "'] end";
         luaL_dostring(L, code.c_str());
 
         return 1;
     }
 
     int set_final_module(lua_State* L) {
         int n = lua_gettop(L);
         if (n == 0) {
             luaL_error(L, "invalid number of arguments: at least one expected, got 0");
         }
 
         void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(1));
         ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
 
         for(size_t i = 1; i <= size_t(n); i++) {
             std::string input_tag = luaL_checkstring(L, i);
             if (!InputTag::isInputTag(input_tag)) {
                 luaL_error(L, "'%s' is not a valid InputTag", input_tag.c_str());
             }
             callback->onIntegrandDeclared(InputTag::fromString(input_tag));
         }
 
         return 0;
     }
 
     int add_integration_dimension(lua_State* L) {
         int n = lua_gettop(L);
         if (n != 0) {
             luaL_error(L, "invalid number of arguments: 0 expected, got %d", n);
         }
 
         // Create input tag using current value of the index
         int64_t cuba_index = lua_tonumber(L, lua_upvalueindex(1));
         lua_pushnumber(L, cuba_index + 1);
         lua_replace(L, lua_upvalueindex(1));
 
         std::string index_tag = "cuba::ps_points/";
         index_tag += std::to_string(cuba_index);
 
         // Input tag is return value of the function
         push_any(L, index_tag);
 
         // Add an integration dimension in the configuration
         void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(2));
         ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
         callback->addIntegrationDimension();
 
         return 1;
     }
 
     int declare_input(lua_State* L) {
         int n = lua_gettop(L);
         if (n != 1) {
             luaL_error(L, "invalid number of arguments: 1 expected, got %d", n);
         }
 
         std::string input_name = luaL_checkstring(L, 1);
 
         void* cfg_ptr = lua_touserdata(L, lua_upvalueindex(1));
         ILuaCallback* callback = static_cast<ILuaCallback*>(cfg_ptr);
 
         callback->onNewInputDeclared(input_name);
 
         return 0;
     }
 
     void setup_hooks(lua_State* L, void* ptr) {
         lua_pushlightuserdata(L, ptr);
         lua_pushcclosure(L, load_modules, 1);
         lua_setglobal(L, "load_modules");
 
         lua_pushlightuserdata(L, ptr);
         lua_pushcclosure(L, parameter, 1);
         lua_setglobal(L, "parameter");
 
         // Define the `add_dimension()` function in Lua and make it available in the global namespace.
         // See add_integration_dimension for more information.
         lua_pushnumber(L, 1);
         lua_pushlightuserdata(L, ptr);
         lua_pushcclosure(L, add_integration_dimension, 2);
         lua_setglobal(L, "add_dimension");
 
         // integrand() function
         lua_pushlightuserdata(L, ptr);
         lua_pushcclosure(L, set_final_module, 1);
         lua_setglobal(L, "integrand");
 
         // momemta_declare_input function
         lua_pushlightuserdata(L, ptr);
         lua_pushcclosure(L, declare_input, 1);
         lua_setglobal(L, "momemta_declare_input");
 
         // C++ -> lua bindings of some classes
         path_register(L, ptr);
     }
 
     std::shared_ptr<lua_State> init_runtime(ILuaCallback* callback) {
 
         std::shared_ptr<lua_State> L(luaL_newstate(), lua_close);
         luaL_openlibs(L.get());
 
         // Register hooks function, like `load_modules`
         lua::setup_hooks(L.get(), callback);
 
         // Register existing modules
         lua::register_modules(L.get(), callback);
 
         // Default functions
         // Function defined by `embedLua.py` at build-time
         execute_embed_lua_code(L.get());
 
         return L;
     }
 
     void inject_parameters(lua_State* L, const ParameterSet& parameters) {
         for (const auto& parameter: parameters.getNames()) {
             LOG(debug) << "Injecting parameter " << parameter;
             lua::push_any(L, parameters.rawGet(parameter));
             lua_setglobal(L, parameter.c_str());
         }
     }
 
     namespace debug {
         std::vector<std::string> dump_stack(lua_State *L) {
             std::vector<std::string> stack;
             for (int i = 1; i < lua_gettop(L) + 1; i++) {
                 if (lua_isnumber(L, i)) {
                     stack.push_back("number : " + std::to_string(lua_tonumber(L, i)));
                 } else if (lua_isstring(L, i)) {
                     stack.push_back(std::string("string : ") + std::string(luaL_checkstring(L, i)));
                 } else if (lua_istable(L, i)) {
                     stack.push_back("table");
                 } else if (lua_iscfunction(L, i)) {
                     stack.push_back("cfunction");
                 } else if (lua_isfunction(L, i)) {
                     stack.push_back("function");
                 } else if (lua_isboolean(L, i)) {
                     if (lua_toboolean(L, i) != 0)
                         stack.push_back("boolean: true");
                     else
                         stack.push_back("boolean: false");
                 } else if (lua_isuserdata(L, i)) {
                     stack.push_back("userdata");
                 } else if (lua_isnil(L, i)) {
                     stack.push_back("nil");
                 } else if (lua_islightuserdata(L, i)) {
                     stack.push_back("lightuserdata");
                 }
             }
 
             return stack;
         }
 
         void print_stack(lua_State* L) {
             auto stack = dump_stack(L);
             size_t index = 0;
             LOG(debug) << "Stack has " << stack.size() << " elements: ";
             for (const auto& e: stack) {
                 LOG(debug) << "  #" << index++ << ": " << e;
             }
         }
     }
 }
All Classes Namespaces Files Functions Variables Enumerations Enumerator Macros Modules Pages