//    MCL: MiMiC Communication Library
//    Copyright (C) 2015-2025  The MiMiC Authors (see CONTRIBUTORS file for details).
//
//    This file is part of MCL.
//
//    MCL is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Lesser General Public License as
//    published by the Free Software Foundation, either version 3 of
//    the License, or (at your option) any later version.
//
//    MCL is distributed in the hope that it will be useful, but
//    WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public License
//    along with this program.  If not, see <http://www.gnu.org/licenses/>.

#include "main_class.h"

#include <memory>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "data_types.h"
#include "error_codes.h"

using namespace mcl;
using ::testing::Exactly;
using ::testing::AtMost;
using ::testing::_;
using ::testing::Return;
using ::testing::ContainsRegex;
using ::testing::internal::CaptureStderr;
using ::testing::internal::GetCapturedStderr;

class MockTransport : public Transport {
  public:
    MockTransport() : Transport() {}
    ~MockTransport() override = default;

    MOCK_METHOD(int,  Initialize,       (void *args), (override));
    MOCK_METHOD(int,  Finalize,         (), (override));
    MOCK_METHOD(int,  Abort,            (int error_code), (override));
    MOCK_METHOD(int,  Send,             (void* data, int length, int data_type, int tag, int destination), (override));
    MOCK_METHOD(int,  Receive,          (void* data, int length, int data_type, int tag, int source), (override));
    MOCK_METHOD(int,  ProbeSize,        (int data_type, int source), (override));

    MOCK_METHOD(bool, is_server,        (), (const));
    MOCK_METHOD(bool, is_client,        (), (const));
    MOCK_METHOD(int,  num_programs,     (), (const));
    MOCK_METHOD(int,  program_id,       (), (const));
};

// Message tags passed through the Send and Receive tests.
static constexpr int kSendTag = 16;
static constexpr int kReceiveTag = 17;
// Text the death tests expect on stderr when MclMain is used before being
// initialized. It is handed to ContainsRegex, so it is interpreted as a
// regular expression. Immutable and file-local like the constants above.
static const std::string uninitialized_error = "It seems that MCL was not initialized properly.";

// Action plugged into mocked Receive calls: writes the id of the sending
// program into the receive buffer and returns 0.
// `data` must point to at least one writable int. The count, data_type and
// tag arguments exist only to match the Receive signature and are ignored.
int ReceiveSource(void* data, int count, int data_type, int tag, int source) {
    int* const destination = static_cast<int*>(data);
    destination[0] = source;
    return 0;
}


// Before any Initialize() call the singleton must not report itself as initialized.
TEST(Uninitialized, IsInitialized) {
    ASSERT_FALSE(MclMain::GetInstance().is_initialized());
}

// Send() on an uninitialized MclMain must terminate the process and print the
// "not initialized" diagnostic.
TEST(Uninitialized, Send) {
    int payload;
    ASSERT_DEATH(
        MclMain::GetInstance().Send(&payload, 1, 0, 0, 0),
        ContainsRegex(uninitialized_error));
}

// Receive() on an uninitialized MclMain must terminate the process and print
// the "not initialized" diagnostic.
TEST(Uninitialized, Receive) {
    int inbox;
    ASSERT_DEATH(
        MclMain::GetInstance().Receive(&inbox, 1, 0, 0, 0),
        ContainsRegex(uninitialized_error));
}

// Abort() on an uninitialized MclMain must terminate the process and print the
// "not initialized" diagnostic.
TEST(Uninitialized, Abort) {
    // A fixed code keeps the test reproducible and avoids depending on
    // <cstdlib> for std::rand(); the value itself is arbitrary.
    constexpr int kArbitraryErrorCode = 1;
    ASSERT_DEATH(MclMain::GetInstance().Abort(kArbitraryErrorCode), ContainsRegex(uninitialized_error));
}

// ProbeSize() on an uninitialized MclMain must terminate the process and print
// the "not initialized" diagnostic.
TEST(Uninitialized, ProbeSize) {
    ASSERT_DEATH(
        MclMain::GetInstance().ProbeSize(kTypeInt, 0),
        ContainsRegex(uninitialized_error));
}

// Querying num_programs() before initialization must terminate the process
// and print the "not initialized" diagnostic.
TEST(Uninitialized, GetNumPrograms) {
    ASSERT_DEATH(
        MclMain::GetInstance().num_programs(),
        ContainsRegex(uninitialized_error));
}

// Querying program_id() before initialization must terminate the process and
// print the "not initialized" diagnostic.
TEST(Uninitialized, GetProgramId) {
    ASSERT_DEATH(
        MclMain::GetInstance().program_id(),
        ContainsRegex(uninitialized_error));
}

// Finalize() without a preceding Initialize() must terminate the process and
// print the "not initialized" diagnostic.
TEST(Uninitialized, Finalize) {
    ASSERT_DEATH(
        MclMain::GetInstance().Finalize(),
        ContainsRegex(uninitialized_error));
}


class ServerTest : public ::testing::Test {
  protected:
    void ExpectEndpointTypeInquiry(std::shared_ptr<MockTransport>& protocol_, int max_times) {
        EXPECT_CALL(*protocol_, is_server()).Times(AtMost(max_times)).WillRepeatedly(Return(is_server_));
        EXPECT_CALL(*protocol_, is_client()).Times(AtMost(max_times)).WillRepeatedly(Return(!is_server_));
    }

    void SetUp() override {
        MclMain::GetInstance().set_protocol(protocol_);
        EXPECT_CALL(*protocol_, Initialize(nullptr)).Times(Exactly(1));
        ExpectEndpointTypeInquiry(protocol_, 2);
        EXPECT_CALL(*protocol_, program_id()).Times(Exactly(1)).WillOnce(Return(0));
        EXPECT_CALL(*protocol_, num_programs()).Times(Exactly(1)).WillOnce(Return(num_programs_));
        for (int i = 1; i < num_programs_; ++i)
            EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, i)).Times(Exactly(1)).WillOnce(ReceiveSource);
        MclMain::GetInstance().Initialize(nullptr);
    }
    
    void TearDown() override {
        EXPECT_CALL(*protocol_, Finalize()).Times(Exactly(1));
        MclMain::GetInstance().Finalize();
        ASSERT_TRUE(!MclMain::GetInstance().is_initialized());
    }

    std::shared_ptr<MockTransport> protocol_ = std::make_shared<MockTransport>();
    int num_programs_ = 5;
    bool is_server_ = true;
    
};

// After SetUp the endpoint must have id 0 and list the ids 1..num_programs_-1
// in ascending order, and MclMain must report itself as initialized.
TEST_F(ServerTest, Initialize) {
    const auto endpoint = MclMain::GetInstance().endpoint();
    ASSERT_EQ(endpoint->id(), 0);

    const int location_count = static_cast<int>(endpoint->location_ids().size());
    ASSERT_EQ(location_count, num_programs_ - 1);
    for (int index = 0; index < location_count; ++index) {
        const int expected_id = index + 1;
        ASSERT_EQ(expected_id, endpoint->location_ids()[index]);
    }

    ASSERT_TRUE(MclMain::GetInstance().is_initialized());
}

// A Receive() from every location id must reach the transport exactly once
// with the same buffer, length, type, tag and source.
TEST_F(ServerTest, Receive) {
    int buffer[5];
    ExpectEndpointTypeInquiry(protocol_, num_programs_ - 1);

    const auto endpoint = MclMain::GetInstance().endpoint();
    const int location_count = static_cast<int>(endpoint->location_ids().size());

    // Register all expectations before issuing any call.
    for (int source = 1; source <= location_count; ++source)
        EXPECT_CALL(*protocol_, Receive(buffer, 5, kTypeInt, kReceiveTag, source)).Times(Exactly(1));

    for (int source = 1; source <= location_count; ++source)
        MclMain::GetInstance().Receive(buffer, 5, kTypeInt, kReceiveTag, source);
}

// A Send() to every location id must reach the transport exactly once with
// the same buffer, length, type, tag and destination.
TEST_F(ServerTest, Send) {
    int buffer[5];
    ExpectEndpointTypeInquiry(protocol_, num_programs_ - 1);

    const auto endpoint = MclMain::GetInstance().endpoint();
    const int location_count = static_cast<int>(endpoint->location_ids().size());

    // Register all expectations before issuing any call.
    for (int destination = 1; destination <= location_count; ++destination)
        EXPECT_CALL(*protocol_, Send(buffer, 5, kTypeInt, kSendTag, destination)).Times(Exactly(1));

    for (int destination = 1; destination <= location_count; ++destination)
        MclMain::GetInstance().Send(buffer, 5, kTypeInt, kSendTag, destination);
}

// ProbeSize() must query the transport once per location id and hand back
// the value the transport returned (here: source id minus one).
TEST_F(ServerTest, ProbeSize) {
    ExpectEndpointTypeInquiry(protocol_, num_programs_ - 1);

    const auto endpoint = MclMain::GetInstance().endpoint();
    const int location_count = static_cast<int>(endpoint->location_ids().size());

    // Register all expectations before issuing any call.
    for (int source = 1; source <= location_count; ++source) {
        const int scripted_size = source - 1;
        EXPECT_CALL(*protocol_, ProbeSize(kTypeInt, source)).Times(Exactly(1)).WillOnce(Return(scripted_size));
    }

    for (int source = 1; source <= location_count; ++source) {
        const int scripted_size = source - 1;
        ASSERT_EQ(MclMain::GetInstance().ProbeSize(kTypeInt, source), scripted_size);
    }
}


// Fixture that initializes MclMain with a transport reporting the client role.
// TearDown and the helper are inherited from ServerTest.
class ClientTest : public ServerTest {
  protected:
    // Value returned by the mocked program_id().
    int client_id_ = 4;

    // Scripts the initialization calls expected for a client: one
    // Initialize(nullptr), the role queries, one program_id() query and a
    // single-int Send to program 0.
    void SetUp() override {
        is_server_ = false;
        MclMain::GetInstance().set_protocol(protocol_);

        EXPECT_CALL(*protocol_, Initialize(nullptr))
            .Times(Exactly(1));
        ExpectEndpointTypeInquiry(protocol_, 2);
        EXPECT_CALL(*protocol_, program_id())
            .Times(Exactly(1))
            .WillOnce(Return(client_id_));
        EXPECT_CALL(*protocol_, Send(_, 1, kTypeInt, _, 0))
            .Times(Exactly(1));

        MclMain::GetInstance().Initialize(nullptr);
    }
};

// After SetUp the endpoint must carry the client id and know exactly one
// location: id 0.
TEST_F(ClientTest, Initialize) {
    const auto endpoint = MclMain::GetInstance().endpoint();
    ASSERT_EQ(endpoint->id(), client_id_);

    ASSERT_EQ(endpoint->location_ids().size(), 1u);
    ASSERT_EQ(endpoint->location_ids()[0], 0);
}

// A Send() to program 0 must reach the transport exactly once with unchanged
// arguments.
TEST_F(ClientTest, Send) {
    int buffer[5];
    ExpectEndpointTypeInquiry(protocol_, 1);
    EXPECT_CALL(*protocol_, Send(buffer, 5, kTypeInt, kSendTag, 0))
        .Times(Exactly(1));

    MclMain::GetInstance().Send(buffer, 5, kTypeInt, kSendTag, 0);
}

// A Receive() from program 0 must reach the transport exactly once with
// unchanged arguments.
TEST_F(ClientTest, Receive) {
    int buffer[5];
    ExpectEndpointTypeInquiry(protocol_, 1);
    EXPECT_CALL(*protocol_, Receive(buffer, 5, kTypeInt, kReceiveTag, 0))
        .Times(Exactly(1));

    MclMain::GetInstance().Receive(buffer, 5, kTypeInt, kReceiveTag, 0);
}

// Abort() must be passed to the transport exactly once with the same error code.
TEST_F(ClientTest, Abort) {
    EXPECT_CALL(*protocol_, Abort(kMclAlreadyInit))
        .Times(Exactly(1))
        .WillOnce(Return(kMclAlreadyInit));

    MclMain::GetInstance().Abort(kMclAlreadyInit);
}


// Fixture for tests of failing Initialize() calls. It only installs the mock
// transport; each test scripts and triggers initialization itself.
class InitializeFailTest : public ServerTest {
  protected:
    void SetUp() override {
        MclMain::GetInstance().set_protocol(protocol_);
    }

    // Intentionally empty: replaces ServerTest::TearDown, which expects and
    // performs a Finalize() call.
    void TearDown() override {}

    // Status the mocked transport returns to simulate a failure.
    const int FAIL_STATUS = 1;
};


// A second Initialize() on an already initialized MclMain must call
// Abort(kMclAlreadyInit), return that code and print an error to stderr.
TEST_F(InitializeFailTest, AlreadyInitialized) {
    // First initialization: scripted with the same expectations as ServerTest::SetUp.
    EXPECT_CALL(*protocol_, Initialize(nullptr)).Times(Exactly(1));
    ExpectEndpointTypeInquiry(protocol_, 2);
    EXPECT_CALL(*protocol_, program_id()).Times(Exactly(1)).WillOnce(Return(0));
    EXPECT_CALL(*protocol_, num_programs()).Times(Exactly(1)).WillOnce(Return(num_programs_));
    for (int i = 1; i < num_programs_; ++i) {
        EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, i)).Times(Exactly(1)).WillOnce(ReceiveSource);
    }
    MclMain::GetInstance().Initialize(nullptr);

    EXPECT_CALL(*protocol_, Abort(kMclAlreadyInit)).Times(Exactly(1)).WillOnce(Return(kMclAlreadyInit));

    // End the stderr capture before checking anything, so that a failed check
    // can never leave the capture active for the tests that run afterwards.
    CaptureStderr();
    const auto status = MclMain::GetInstance().Initialize(nullptr);
    const std::string error = GetCapturedStderr();

    // Non-fatal checks: the Finalize() below must run even when one of them
    // fails, otherwise the singleton would stay initialized for later tests.
    EXPECT_EQ(status, kMclAlreadyInit);
    EXPECT_THAT(error, ContainsRegex("~~~ MCL ~ ERROR ~~~"));
    EXPECT_THAT(error, ContainsRegex("MCL_Initialized is called, but MCL is already initialized!"));

    EXPECT_CALL(*protocol_, Finalize()).Times(Exactly(1));
    MclMain::GetInstance().Finalize();
}

// When the transport's Initialize() fails, MclMain must call Abort() with that
// status and return what Abort() returned.
TEST_F(InitializeFailTest, ProtocolInitialize) {
    EXPECT_CALL(*protocol_, Initialize(_))
        .Times(Exactly(1))
        .WillOnce(Return(FAIL_STATUS));
    EXPECT_CALL(*protocol_, Abort(FAIL_STATUS))
        .Times(Exactly(1))
        .WillOnce(Return(FAIL_STATUS));

    const auto status = MclMain::GetInstance().Initialize(nullptr);
    ASSERT_EQ(status, FAIL_STATUS);
}

// When the transport initializes fine but the first Receive during endpoint
// setup fails, MclMain must call Abort() with that status and return what
// Abort() returned.
TEST_F(InitializeFailTest, EndpointInitialize) {
    EXPECT_CALL(*protocol_, Initialize(_))
        .Times(Exactly(1))
        .WillOnce(Return(kSuccess));
    ExpectEndpointTypeInquiry(protocol_, 2);
    EXPECT_CALL(*protocol_, program_id())
        .Times(Exactly(1))
        .WillOnce(Return(0));
    EXPECT_CALL(*protocol_, num_programs())
        .Times(Exactly(1))
        .WillOnce(Return(num_programs_));
    // Only the Receive from program 1 is scripted, and it reports failure.
    EXPECT_CALL(*protocol_, Receive(_, 1, kTypeInt, _, 1))
        .Times(Exactly(1))
        .WillOnce(Return(FAIL_STATUS));
    EXPECT_CALL(*protocol_, Abort(FAIL_STATUS))
        .Times(Exactly(1))
        .WillOnce(Return(FAIL_STATUS));

    const auto status = MclMain::GetInstance().Initialize(nullptr);
    ASSERT_EQ(status, FAIL_STATUS);
}

// Test entry point: lets GoogleMock consume its command-line flags, then runs
// every registered test and uses the outcome as the process exit code.
int main(int argc, char** argv) {
    ::testing::InitGoogleMock(&argc, argv);
    return RUN_ALL_TESTS();
}
