//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "mpi_port_transport.h"

#include <algorithm>
#include <cmath>
#include <filesystem>
#include <fstream>
#include <future>
#include <string>
#include <thread>

#include "mpi.h"

#include "data_types.h"
#include "env_controls.h"
#include "error_codes.h"
#include "error_print.h"

// Evaluates the MPI call FUNCTION exactly once. If the call does not return
// MPI_SUCCESS, an error message with the call site and the MPI error code is
// printed and the enclosing function returns kErrMpiInitFail, so the macro may
// only be used inside functions whose return type accepts that code.
// The body is wrapped in do { } while (0) so that "CHECK_MPI(...);" is a single
// statement and stays correct inside an unbraced if/else or loop.
#define CHECK_MPI(FUNCTION) do { \
            const int mcl_check_mpi_status_ = (FUNCTION); \
            if (mcl_check_mpi_status_ != MPI_SUCCESS) { \
                PrintErrorMessage("There has been a problem in the MPI Port communication protocol!", \
                                  __FILE__, __LINE__, mcl_check_mpi_status_); \
                return kErrMpiInitFail; \
            } \
        } while (0)

namespace mcl {

int MpiPortTransport::Initialize(void *args) {
    // set timeout
    timeout_ = env_timeout_value.value();
    if (env_timeout_value.found())
        if (!env_timeout_value.IsDefined()) return kErrInvalidEnvVar;
    if (std::signbit(timeout_)) {
        PrintWarning("The requested environment variable, " + AddQuotes(env_timeout_value.name()) +
                     " has a negative value!\n" "Using the default value instead!");
        timeout_ = env_timeout_value.default_value();;
    }

    // read program number
    program_id_ = env_program_number.value();
    if (!env_program_number.IsDefined()) return kErrInvalidEnvVar;

    // run initialize
    std::future<int> request = std::async(std::launch::async, &MpiPortTransport::InnerInitialize, this);
    if (request.wait_for(std::chrono::milliseconds(timeout_)) == std::future_status::timeout) {
        PrintErrorMessage("Failed to initialize! Initialize in program no. " + std::to_string(program_id_) +
                          " timed out!", __FILE__, __LINE__);
        return kErrMpiTimeout;
    }
    return request.get();
}
    
int MpiPortTransport::InnerInitialize() {
    namespace fs = std::filesystem;
    
    int local_rank;
    MPI_Comm_rank(local_comm_, &local_rank);
    
    constexpr auto kPortFileName = ".MCL_portname";
    constexpr auto kNumProgFileName = ".MCL_numprogs";

    std::string port_directory = env_port_directory.value();
    if (!env_port_directory.IsDefined()) return kErrInvalidEnvVar;

    std::string port_file = port_directory + "/" + kPortFileName + "_";
    std::string num_prog_file = port_directory + "/" + kNumProgFileName;
    std::string barrier_file = port_directory + "/" + ".MCL_write_portfiles";

    // determine which process is the server and which are clients
    if (local_rank == 0) {
        is_server_ = (program_id_ == 0);
        is_client_ = (program_id_ != 0);
    }

    // count the number of programs in simulation
    if (is_server_) {
        // remove potentially undeleted files from previous simulations
        fs::remove(num_prog_file);
        for (auto &file : fs::directory_iterator(port_directory)) {
            std::string file_name = file.path().filename().string();
            if (file_name.find(kPortFileName) == 0)
                fs::remove(file);
        }
        
        // create a barrier file to instruct clients to create port files 
        std::this_thread::sleep_for(std::chrono::milliseconds(250));
        std::fstream f(barrier_file, std::fstream::out);
        if (f.is_open()) f.close();

        // count the total number of port files, i.e., the number of programs
        num_programs_ = -1;
        int client_count = 0;
        while (num_programs_ != client_count) {
            num_programs_ = client_count;
            client_count = 0;
            std::this_thread::sleep_for(std::chrono::milliseconds(250));
            for (auto &file : fs::directory_iterator(port_directory)) {
                std::string file_name = file.path().filename().string();
                if (file_name.find(kPortFileName) == 0)
                    client_count++;
            }
        }
        ++num_programs_;
        
        // write the total number of programs to a file to share with clients
        f.open(num_prog_file, std::fstream::out);
        if (f.is_open()) {
            f << std::to_string(num_programs_);
            f.close();
        }
    }
    if (is_client_) {
        port_file += std::to_string(program_id_);

        // wait for the signal from server (barrier file created)
        while (!fs::exists(barrier_file));

        // write the port file
        std::fstream f(port_file, std::fstream::out);
        if (f.is_open()) f.close();
        
        // wait for the signal from server (barrier file deleted)
        while (fs::exists(barrier_file));
        
        // get the total number of programs from the server
        f.open(num_prog_file, std::fstream::in);
        if (f.good()) {
            f >> num_programs_;
            f.close();
        }
    }
    CHECK_MPI(MPI_Bcast(&num_programs_, 1, MPI_INT, 0, local_comm_));


    // create intercommunicators between the server and clients
    std::vector<MPI_Comm> intercomms;
    std::vector<std::string> port_names;
    if (is_server_) {
        // open ports for client programs and store them in the port files
        char port_name[MPI_MAX_PORT_NAME];
        intercomms.push_back(MPI_COMM_SELF);
        port_names.push_back(std::string());
        for (int i = 1; i < num_programs_; ++i) {
            CHECK_MPI(MPI_Open_port(MPI_INFO_NULL, port_name));
            port_names.push_back(port_name);
            intercomms.push_back(MPI_COMM_NULL);
            std::string port_file_tmp = port_file + std::to_string(i);
            std::fstream f(port_file_tmp, std::fstream::out);
            if (f.is_open()) {
                f << port_name;
                f.close();
            }
        }
        
        // remove the barrier file to instruct clients to read the port files
        fs::remove(barrier_file);
        
        // check that the port files were deleted and remove the file with num_programs
        bool deleted = false;
        while (!deleted) {
            deleted = true;
            for (auto &file : fs::directory_iterator(port_directory)) {
                std::string file_name = file.path().filename().string();
                if (file_name.find(kPortFileName) == 0)
                    deleted = false;
            }
        }
        fs::remove(num_prog_file);
    
        // connect to clients
        for (int i = 1; i < num_programs_; ++i)
            CHECK_MPI(MPI_Comm_accept(port_names[i].c_str(), MPI_INFO_NULL, 0, MPI_COMM_SELF, &intercomms[i]));
    }

    if (is_client_) {
        // read the port name from the respective port file
        char port_name[MPI_MAX_PORT_NAME];
        std::fstream f(port_file, std::fstream::in);
        if (f.good()) {
            f >> port_name;
            f.close();
        }
        fs::remove(port_file);

        // connect to the server
        intercomms.push_back(MPI_Comm{});
        CHECK_MPI(MPI_Comm_connect(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &intercomms[0]));
    }


    // merge communicators into a single communicator containing all processes
    int kPortTag = 11235;
    bool open_port = (program_id_ == 0);
    for (int i = 1; i < num_programs_; ++i) {
        char port_name[MPI_MAX_PORT_NAME];
        MPI_Comm intercomm;
        if (open_port) {
            // open port and send its name
            CHECK_MPI(MPI_Open_port(MPI_INFO_NULL, port_name));
            if (is_server_) {
                CHECK_MPI(MPI_Send(&port_name, MPI_MAX_PORT_NAME, MPI_CHAR, 0, kPortTag, intercomms[i]));
                CHECK_MPI(MPI_Comm_disconnect(&intercomms[i]));
                CHECK_MPI(MPI_Close_port(port_names[i].c_str()));
            }
            CHECK_MPI(MPI_Barrier(world_comm_));

            // accept connection and merge
            CHECK_MPI(MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, world_comm_, &intercomm));
            CHECK_MPI(MPI_Intercomm_merge(intercomm, 0, &world_comm_));
        } else if (program_id_ == i) {
            // receive port name
            if (is_client_) {
                MPI_Status status;
                CHECK_MPI(MPI_Recv(&port_name, MPI_MAX_PORT_NAME, MPI_CHAR, 0, kPortTag, intercomms[0], &status));
                CHECK_MPI(MPI_Comm_disconnect(&intercomms[0]));
            }
            CHECK_MPI(MPI_Bcast(&port_name, MPI_MAX_PORT_NAME, MPI_CHAR, 0, local_comm_));

            // connect and merge
            CHECK_MPI(MPI_Comm_connect(port_name, MPI_INFO_NULL, 0, local_comm_, &intercomm));
            CHECK_MPI(MPI_Intercomm_merge(intercomm, 1, &world_comm_));
        }

        // cleanup
        if (open_port || program_id_ == i) {
            CHECK_MPI(MPI_Bcast(&open_port, 1, MPI_CXX_BOOL, 0, world_comm_));
            CHECK_MPI(MPI_Comm_disconnect(&intercomm));
            if (open_port) CHECK_MPI(MPI_Close_port(port_name));
        }
    }
 

    // create an intercommunicator
    int world_rank;
    MPI_Comm_rank(world_comm_, &world_rank);
    std::vector<int> intercomm_ranks(num_programs_, 0);
    intercomm_ranks[program_id_] = (local_rank == 0) ? world_rank : 0;
    CHECK_MPI(MPI_Allreduce(MPI_IN_PLACE, intercomm_ranks.data(), num_programs_, MPI_INT, MPI_SUM, world_comm_));

    if (local_rank == 0) {
        MPI_Group world;
        MPI_Comm_group(world_comm_, &world);
        MPI_Group intercomm_group;
        CHECK_MPI(MPI_Group_incl(world, num_programs_, intercomm_ranks.data(), &intercomm_group));
        CHECK_MPI(MPI_Comm_create_group(world_comm_, intercomm_group, 1, &intercomm_));
    }

    return kSuccess;
}

} // namespace mcl
