/*
 * This file is part of ArmarX.
 *
 * ArmarX is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 *
 * ArmarX is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 * @package    RobotAPI::ArmarXObjects::DMPComponent
 * @author     Mirko Waechter ( mirko dot waechter at kit dot edu )
 * @date       2015
 * @copyright  http://www.gnu.org/licenses/gpl-2.0.txt
 *             GNU General Public License
 */

#include "DMPComponent.h"
#include <MMM/Motion/MotionReaderXML.h>
#include <dmp/io/MMMConverter.h>
#include <dmp/testing/testdataset.h>

using namespace armarx;


/**
 * Init hook: declares the proxy named by the "LongtermMemoryName" property
 * via usingProxy(); onConnectComponent() resolves the same name with getProxy().
 */
void DMPComponent::onInitComponent()
{
    ARMARX_INFO << "initializing DMP component";

    const std::string memoryName = getProperty<std::string>("LongtermMemoryName").getValue();
    usingProxy(memoryName);

    ARMARX_INFO << "successfully initialized DMP component";
}


/**
 * Connect hook: resolves the long-term memory proxy (named by the
 * "LongtermMemoryName" property) and fetches its DMP segment.
 *
 * On any failure an error (including the exception message when available)
 * is logged and the method returns early, leaving the DMP segment proxy unset.
 */
void DMPComponent::onConnectComponent()
{
    ARMARX_INFO << "connecting DMP component";

    const std::string memoryName = getProperty<std::string>("LongtermMemoryName").getValue();

    try
    {
        longtermMemoryPrx = getProxy<memoryx::LongtermMemoryInterfacePrx>(memoryName);
    }
    catch (const std::exception& e)
    {
        ARMARX_ERROR << "cannot get longterm memory proxy: " << e.what();
        return;
    }
    catch (...)
    {
        ARMARX_ERROR << "cannot get longterm memory proxy";
        return;
    }

    // Guard against a null proxy so getDMPSegment() is never invoked on it.
    if (!longtermMemoryPrx)
    {
        ARMARX_ERROR << "cannot get longterm memory proxy";
        return;
    }

    try
    {
        dmpDataMemoryPrx = longtermMemoryPrx->getDMPSegment();
    }
    catch (const std::exception& e)
    {
        ARMARX_ERROR << "cannot get dmp segment of longterm memory: " << e.what();
        return;
    }
    catch (...)
    {
        ARMARX_ERROR << "cannot get dmp segment of longterm memory";
        return;
    }

    ARMARX_INFO << "successfully connected DMP component";
}


/**
 * Disconnect hook: only logs. The proxies acquired in onConnectComponent()
 * are not reset here.
 */
void DMPComponent::onDisconnectComponent()
{
    ARMARX_INFO << "disconnecting DMP component";
}


/**
 * Exit hook: only logs; no resources are released explicitly here.
 */
void DMPComponent::onExitComponent()
{
    ARMARX_INFO << "exiting DMP component";
}


/**
 * Creates a fresh DMP object of the requested type and rewinds playback.
 *
 * Second-order types are stored in basicdmp, third-order types in dmp3rdorder.
 * is3rdOrder is updated to match, so that setDMPState(), getNextState() and
 * storeDMPInDatabase() operate on the object that was actually created.
 *
 * @param DMPType one of the ARMARX_DMPTYPE_* constants. Unsupported types are
 *        rejected with an error and leave the component state unchanged.
 */
void DMPComponent::instantiateDMP(int DMPType, const Ice::Current&)
{
    ARMARX_INFO << "instantiate DMP";

    switch (DMPType)
    {
        case ARMARX_DMPTYPE_BASICDMP:
            basicdmp.reset(new DMP::BasicDMP);
            is3rdOrder = false;
            break;

        case ARMARX_DMPTYPE_DMP3RDORDER:
            dmp3rdorder.reset(new DMP::DMP3rdOrder);
            is3rdOrder = true;
            break;

        case ARMARX_DMPTYPE_DMP3RDORDERFORCEFIELD:
            dmp3rdorder.reset(new DMP::DMP3rdOrderForceField);
            is3rdOrder = true;
            break;

        case ARMARX_DMPTYPE_ENDVELODMP:
            basicdmp.reset(new DMP::EndVeloDMP);
            is3rdOrder = false;
            break;

        case ARMARX_DMPTYPE_FORCEFIELDDMP:
            basicdmp.reset(new DMP::ForceFieldDMP);
            is3rdOrder = false;
            break;

        case ARMARX_DMPTYPE_SIMPLEENDVELODMP:
            basicdmp.reset(new DMP::SimpleEndVeloDMP);
            is3rdOrder = false;
            break;

        // Not supported yet:
        //    case ARMARX_DMPTYPE_ENDVELFORCEFILELDDMP: basicdmp.reset(new DMP::EndVeloForceFieldDMP);break;
        //    case ARMARX_DMPTYPE_QUATERNIONDMP: basicdmp.reset(new DMP::QuaternionDMP); break; //error: QuaternionDMP is an abstract type.
        //    case ARMARX_DMPTYPE_ADAPTIVEGOAL3RDORDERDMP: basicdmp.reset(new DMP::AdaptiveGoal3rdOrderDMP); break; //error:

        case ARMARX_DMPTYPE_PERIODICTRANSIENTDMP:
            ARMARX_INFO << "instantiate PeriodicDMP";
            basicdmp.reset(new DMP::PeriodicTransientDMP);
            is3rdOrder = false;
            break;

        default:
            // Keep dmpType consistent with the objects actually held.
            ARMARX_ERROR << "ERROR: It is not a valid dmp type. " << DMPType;
            return;
    }

    dmpType = DMPType;
    tCurrent = 0;
}


/**
 * Serializes the currently instantiated DMP into a text archive and stores
 * it as a new DMP entity in the DMP segment of the long-term memory.
 *
 * @param name name given to the stored entity and to the DMP inside it.
 *
 * Logs an error and stores nothing if the memory segment proxy is unset or
 * if no DMP object of the current order (is3rdOrder) has been instantiated.
 */
void DMPComponent::storeDMPInDatabase(const std::string& name, const ::Ice::Current&)
{
    ARMARX_INFO << "storing DMP in the database";

    if (!dmpDataMemoryPrx)
    {
        ARMARX_ERROR << "cannot store DMP: dmp segment of longterm memory is not available";
        return;
    }

    if ((is3rdOrder && !dmp3rdorder) || (!is3rdOrder && !basicdmp))
    {
        ARMARX_ERROR << "cannot store DMP: no DMP has been instantiated";
        return;
    }

    memoryx::DMPEntityPtr dmpEntity = new memoryx::DMPEntity(name);
    dmpEntity->setDMPType(dmpType);
    dmpEntity->setDMPName(name);
    dmpEntity->set3rdOrder(is3rdOrder);

    std::stringstream dmptext;
    {
        // Boost archives may finalize their output on destruction, so the
        // archive is scoped and destroyed before the stream is read.
        boost::archive::text_oarchive ar(dmptext);

        if (is3rdOrder)
        {
            ar << boost::serialization::make_nvp("dmp", *dmp3rdorder);
        }
        else
        {
            ar << boost::serialization::make_nvp("dmp", *basicdmp);
        }
    }

    dmpEntity->setDMPtextStr(dmptext.str());

    const std::string entityID = dmpDataMemoryPrx->addEntity(dmpEntity);
    dmpEntity->setId(entityID);

    ARMARX_INFO << "successfully stored DMP";
}


/**
 * Loads the DMP entity named dmpName from the DMP memory segment and
 * deserializes it into a freshly instantiated DMP of the stored type.
 *
 * dmpType and is3rdOrder are taken from the stored entity. Errors (missing
 * segment proxy, unknown name, non-DMP entity, unsupported type, corrupt
 * archive) are logged, and the method returns early.
 */
void DMPComponent::getDMPFromDatabase(const std::string& dmpName, const Ice::Current&)
{
    ARMARX_INFO << "getting DMP from database";

    if (!dmpDataMemoryPrx)
    {
        ARMARX_ERROR << "cannot get DMP: dmp segment of longterm memory is not available";
        return;
    }

    if (!dmpDataMemoryPrx->hasEntityByName(dmpName))
    {
        ARMARX_ERROR << "DMP with name " + dmpName + " does not exist in the database";
        return;
    }

    memoryx::DMPEntityPtr dmpEntity = memoryx::DMPEntityPtr::dynamicCast(dmpDataMemoryPrx->getDMPEntityByName(dmpName));

    if (!dmpEntity)
    {
        ARMARX_ERROR << "Entity with name " + dmpName + " is not a DMP entity";
        return;
    }

    const std::string name = dmpEntity->getDMPName();
    ARMARX_INFO << "DMP with name: " + name + " is loaded";

    dmpType = dmpEntity->getDMPType();
    ARMARX_INFO << "dmp type is " + std::to_string(dmpType);

    // Deserialize into a new object of the stored type. Otherwise the archive
    // would be read through a null pointer (nothing instantiated yet) or into
    // an object of a different DMP type.
    instantiateDMP(dmpType, Ice::Current());

    // The stored flag is authoritative; set it after instantiation.
    is3rdOrder = dmpEntity->get3rdOrder();
    ARMARX_INFO << "is3rdOrder is " + std::to_string(is3rdOrder);

    if ((is3rdOrder && !dmp3rdorder) || (!is3rdOrder && !basicdmp))
    {
        ARMARX_ERROR << "cannot load DMP: no DMP object of type " + std::to_string(dmpType) + " is available";
        return;
    }

    const std::string textStr = dmpEntity->getDMPtextStr();
    ARMARX_INFO << textStr;
    std::istringstream istr(textStr);

    try
    {
        boost::archive::text_iarchive ar(istr);

        if (is3rdOrder)
        {
            ar >> boost::serialization::make_nvp("dmp", *dmp3rdorder);
        }
        else
        {
            ar >> boost::serialization::make_nvp("dmp", *basicdmp);
        }
    }
    catch (const std::exception& e)
    {
        // e.g. boost::archive::archive_exception on malformed stored text
        ARMARX_ERROR << "cannot deserialize DMP " + dmpName + ": " << e.what();
        return;
    }

    ARMARX_INFO << "successfully got dmp from database.";
}


/**
 * Learns the current DMP from the trajectories loaded via readTrajectoryFromFile().
 *
 * Training is implemented for BasicDMP and PeriodicTransientDMP (via basicdmp)
 * and for 3rdOrderDMP (via dmp3rdorder). Other types are reported with a
 * warning instead of being ignored silently. Nothing is learned if no
 * trajectories are loaded or the matching DMP object was not instantiated.
 */
void DMPComponent::trainDMP(const ::Ice::Current&)
{
    ARMARX_INFO << "In Train DMP";

    if (trajs.size() == 0)
    {
        ARMARX_ERROR << "No trajectories available for training. Load them with readTrajectoryFromFile() first.";
        return;
    }

    switch (dmpType)
    {
        case ARMARX_DMPTYPE_BASICDMP:
        case ARMARX_DMPTYPE_PERIODICTRANSIENTDMP:
            if (!basicdmp)
            {
                ARMARX_ERROR << "Cannot train: no DMP has been instantiated.";
                return;
            }

            basicdmp->learnFromTrajectories(trajs);
            break;

        case ARMARX_DMPTYPE_DMP3RDORDER:
            if (!dmp3rdorder)
            {
                ARMARX_ERROR << "Cannot train: no DMP has been instantiated.";
                return;
            }

            dmp3rdorder->learnFromTrajectories(trajs);
            break;

        default:
            ARMARX_WARNING << "Training is not supported for dmp type " << getDMPTypeName() << "; nothing was learned.";
            break;
    }

    ARMARX_INFO << "Exit Train DMP";
}

/**
 * Placeholder: the body is empty, so calling this currently has no effect.
 * NOTE(review): presumably intended to compute a complete trajectory in one
 * call (as opposed to getNextState() stepping) — TODO implement or remove.
 */
void DMPComponent::calculateWholeTrajectory()
{

}

DMP::Vec<DMP::DMPState> DMPComponent::calculateNextState(DMP::Vec<DMP::DMPState>& initialStates, double t, double tInit, DMP::DVec& canonicalValues)
{

    if (configs.find(DMP_PARAMETERS_GOAL) == configs.end())
    {
        ARMARX_ERROR << "The goal of DMP must be specified";
    }

    DMP::DVec goal = boost::get<DMP::DVec>(configs[DMP_PARAMETERS_GOAL]);

    double temporalFactor;

    if (configs.find(DMP_PARAMETERS_TEMPORALFACTOR) != configs.end())
    {
        temporalFactor = boost::get<double>(configs[DMP_PARAMETERS_TEMPORALFACTOR]);
    }
    else
    {
        temporalFactor = 1.0;
    }

    return basicdmp->calculateTrajectoryPoint(t, goal, tInit, initialStates, canonicalValues, temporalFactor);
}

DMP::Vec<DMP::_3rdOrderDMP> DMPComponent::calculateNextState(DMP::Vec<DMP::_3rdOrderDMP>& initialStates, double t, double tInit, DMP::DVec& canonicalValues)
{
    if (configs.find(DMP_PARAMETERS_GOAL) == configs.end())
    {
        ARMARX_ERROR << "The goal of DMP must be specified";
    }

    DMP::DVec goal = boost::get<DMP::DVec>(configs[DMP_PARAMETERS_GOAL]);

    double temporalFactor;

    if (configs.find(DMP_PARAMETERS_TEMPORALFACTOR) != configs.end())
    {
        temporalFactor = boost::get<double>(configs[DMP_PARAMETERS_TEMPORALFACTOR]);
    }
    else
    {
        temporalFactor = 1.0;
    }

    return dmp3rdorder->calculateTrajectoryPoint(t, goal, tInit, initialStates, canonicalValues, temporalFactor);
}


/**
 * Loads a training trajectory from file, chosen by its extension:
 *  - ".xml": MMM motion, joint trajectories are appended to trajs;
 *  - ".csv": sampled trajectory, appended to trajs;
 *  - ".vsg": test data set; replaces all previously loaded trajectories and
 *            uses only the dimensions set via setDimensionsToLearn().
 * Any other extension is rejected with an error.
 */
void DMPComponent::readTrajectoryFromFile(const std::string& file, const Ice::Current&)
{
    const std::string::size_type dotPos = file.rfind('.');
    const std::string ext = dotPos == std::string::npos ? file : file.substr(dotPos + 1);

    if (ext == "xml")
    {
        MMM::MotionReaderXML motionreader;
        MMM::MotionPtr motion = motionreader.loadMotion(file);

        if (!motion)
        {
            ARMARX_ERROR << "Error: cannot load MMM motion from " << file;
            return;
        }

        DMP::SampledTrajectoryV2 traj;
        traj = DMP::MMMConverter::fromMMMJoints(motion);
        trajs.push_back(traj);
    }
    else if (ext == "csv")
    {
        DMP::SampledTrajectoryV2 traj;
        traj.readFromCSVFile(file);
        trajs.push_back(traj);
    }
    else if (ext == "vsg")
    {
        trajs.clear();

        TestDataSet dataSet;
        dataSet.readFromFile(file);
        dataSet.cutConstantBeginning();

        trajs.push_back(DMP::SampledTrajectoryV2());
        dataSet.trajectoryForPositionsV2(trajs[0], usedDimensions);

        ARMARX_INFO << "VSG File LOADED";

        // Only log the anchor point if it actually has two components;
        // indexing it unconditionally could read out of bounds.
        const auto& anchorPoint = trajs[0].m_anchorPoint;

        if (anchorPoint.size() >= 2)
        {
            ARMARX_INFO << "Anchor point of loaded trajectory: " << anchorPoint[0] << " " << anchorPoint[1];
        }

        ARMARX_INFO << "VSG File LOADED DONE";
    }
    else
    {
        ARMARX_ERROR << "Error: The file is not valid ";
        return;
    }
}

/**
 * Currently a no-op: the only body content is the commented-out loop below,
 * which would forward every entry of configs to the DMP via setConfiguration().
 * NOTE(review): getNextState() still calls this — confirm whether forwarding
 * the configuration to the DMP object is required and re-enable or remove.
 */
void DMPComponent::setDMPConfiguration()
{

    //    for(configMap::iterator it = configs.begin(); it != configs.end(); it++){
    //        if(getBasicDMPType() == ARMARX_DMPTYPE_BASICDMP){
    //            basicdmp->setConfiguration(it->first, it->second);
    //        }else if(getBasicDMPType() == ARMARX_DMPTYPE_DMP3RDORDER){
    //            dmp3rdorder->setConfiguration(it->first, it->second);
    //        }
    //    }

}

/**
 * Stores paraVals[i] under paraIDs[i] in the component's configuration map,
 * inserting new keys and overwriting existing ones.
 *
 * If the two lists differ in length, a warning is logged and only the first
 * min(size) pairs are used, so neither list is indexed out of bounds.
 *
 * @return a copy of the updated configuration map.
 */
configMap DMPComponent::constructConfigMap(DMP::Vec<int> paraIDs, DMP::Vec<paraType> paraVals)
{
    if (paraIDs.size() != paraVals.size())
    {
        ARMARX_WARNING << "ID list and value list have different sizes; surplus entries are ignored.";
    }

    const size_t count = paraIDs.size() < paraVals.size() ? paraIDs.size() : paraVals.size();

    for (size_t i = 0; i < count; i++)
    {
        // operator[] inserts missing keys and overwrites existing ones.
        configs[paraIDs[i]] = paraVals[i];
    }

    return configs;
}


/**
 * Human-readable name of the current dmpType.
 *
 * @return the type name, or "unknownDMP" (after logging an error) for values
 *         that do not match any ARMARX_DMPTYPE_* constant.
 */
std::string DMPComponent::getDMPTypeName()
{
    const char* typeName = nullptr;

    switch (dmpType)
    {
        case ARMARX_DMPTYPE_BASICDMP:
            typeName = "BasicDMP";
            break;

        case ARMARX_DMPTYPE_ENDVELODMP:
            typeName = "EndVeloDMP";
            break;

        case ARMARX_DMPTYPE_SIMPLEENDVELODMP:
            typeName = "SimpleEndVeloDMP";
            break;

        case ARMARX_DMPTYPE_FORCEFIELDDMP:
            typeName = "ForceFieldDMP";
            break;

        case ARMARX_DMPTYPE_ENDVELFORCEFILELDDMP:
            typeName = "EndVeloForceFieldDMP";
            break;

        case ARMARX_DMPTYPE_PERIODICTRANSIENTDMP:
            typeName = "PeriodicTransientDMP";
            break;

        case ARMARX_DMPTYPE_DMP3RDORDER:
            typeName = "3rdOrderDMP";
            break;

        case ARMARX_DMPTYPE_DMP3RDORDERFORCEFIELD:
            typeName = "3rdOrderForceFieldDMP";
            break;

        case ARMARX_DMPTYPE_ADAPTIVEGOAL3RDORDERDMP:
            typeName = "AdaptiveGoal3rdOrderDMP";
            break;

        case ARMARX_DMPTYPE_QUATERNIONDMP:
            typeName = "QuaternionDMP";
            break;

        default:
            break;
    }

    if (typeName == nullptr)
    {
        ARMARX_ERROR << "This is not a valid dmp type";
        return "unknownDMP";
    }

    return typeName;
}


/**
 * Replaces the current DMP state with the given per-dimension state.
 *
 * For third-order DMPs (is3rdOrder) position, velocity and acceleration are
 * copied into currentDMP3rdOrder; otherwise only position and velocity are
 * copied into currentDMPState.
 */
void DMPComponent::setDMPState(const ::armarx::cStateVec& state, const ::Ice::Current&)
{
    const std::size_t dims = state.size();

    if (is3rdOrder)
    {
        currentDMP3rdOrder.resize(dims);

        for (std::size_t d = 0; d < dims; ++d)
        {
            const auto& src = state[d];
            auto& dst = currentDMP3rdOrder[d];
            dst.pos = src.pos;
            dst.vel = src.vel;
            dst.acc = src.acc;
        }

        return;
    }

    currentDMPState.resize(dims);

    for (std::size_t d = 0; d < dims; ++d)
    {
        const auto& src = state[d];
        auto& dst = currentDMPState[d];
        dst.pos = src.pos;
        dst.vel = src.vel;
    }
}

/**
 * Sets the scalar configuration parameter paraID to value, inserting it if
 * it is not present yet and overwriting it otherwise.
 */
void DMPComponent::setParameter(const int paraID, double value, const ::Ice::Current&)
{
    configs[paraID] = value;
}

/**
 * Stores value as the DMP goal (key DMP_PARAMETERS_GOAL), replacing any
 * previously configured goal.
 */
void DMPComponent::setGoal(const DVector& value, const Ice::Current&)
{
    configs[DMP_PARAMETERS_GOAL] = DMP::DVec(value);
}

/**
 * Stores value as the DMP start position (key DMP_PARAMETERS_STARTPOSITION),
 * replacing any previously configured start position.
 */
void DMPComponent::setStartPosition(const DVector& value, const Ice::Current&)
{
    configs[DMP_PARAMETERS_STARTPOSITION] = DMP::DVec(value);
}

/**
 * Replaces the timestamps that getNextState() steps through.
 */
void DMPComponent::setTimeStamps(const DVector& value, const Ice::Current&)
{
    const DMP::DVec converted(value);
    timestamps = converted;
}

/**
 * Replaces the canonical values passed to calculateNextState() by getNextState().
 */
void DMPComponent::setCanonicalValues(const DVector& value, const Ice::Current&)
{
    const DMP::DVec converted(value);
    canonicalValues = converted;
}


/**
 * Converts second-order DMP states into the Ice state vector, copying
 * position and velocity of every dimension (acc is left value-initialized).
 */
::armarx::cStateVec getcStateVec(const DMP::Vec<DMP::DMPState>& dmpstate)
{
    const std::size_t count = dmpstate.size();
    ::armarx::cStateVec result(count);

    for (std::size_t idx = 0; idx != count; ++idx)
    {
        const auto& in = dmpstate[idx];
        result[idx].pos = in.pos;
        result[idx].vel = in.vel;
    }

    return result;
}

/**
 * Converts third-order DMP states into the Ice state vector, copying
 * position, velocity and acceleration of every dimension.
 */
::armarx::cStateVec getcStateVec(const DMP::Vec<DMP::_3rdOrderDMP>& dmpstate)
{
    const std::size_t count = dmpstate.size();
    ::armarx::cStateVec result(count);

    for (std::size_t idx = 0; idx != count; ++idx)
    {
        const auto& in = dmpstate[idx];
        auto& out = result[idx];
        out.pos = in.pos;
        out.vel = in.vel;
        out.acc = in.acc;
    }

    return result;
}

/**
 * Advances the DMP by one timestamp step and returns the new state.
 *
 * Each call integrates from timestamps[tCurrent] to timestamps[tCurrent + 1]
 * and increments tCurrent. When no further step is possible, an error is
 * logged, tCurrent is rewound to 0, and the current state is returned
 * unchanged. The same applies when timestamps, the current state (see
 * setDMPState()) or the DMP object itself are missing. In those cases the
 * state is returned without indexing past the end of any container.
 */
::armarx::cStateVec DMPComponent::getNextState(const ::Ice::Current&)
{
    ARMARX_INFO << "In getNext State ";
    setDMPConfiguration();

    // Current state, unchanged, in the representation of the active DMP order.
    auto unchangedState = [this]()
    {
        return is3rdOrder ? getcStateVec(currentDMP3rdOrder) : getcStateVec(currentDMPState);
    };

    if (timestamps.size() == 0)
    {
        ARMARX_ERROR << "Timestampes must be specified.";
        return unchangedState();
    }

    ARMARX_INFO << "In getNext State 1";

    if (canonicalValues.size() == 0)
    {
        ARMARX_WARNING << "Canonical value is not specified. It will be set 1.0.";
        canonicalValues.push_back(1.0);
    }

    ARMARX_INFO << "In getNext State 2";

    // Validate inputs before consuming a timestamp.
    const bool stateMissing = is3rdOrder ? currentDMP3rdOrder.size() == 0 : currentDMPState.size() == 0;

    if (stateMissing)
    {
        ARMARX_ERROR << "The current state is not available. Please specify current state with setDMPState().";
        return unchangedState();
    }

    const bool dmpMissing = is3rdOrder ? !dmp3rdorder : !basicdmp;

    if (dmpMissing)
    {
        ARMARX_ERROR << "No DMP has been instantiated. Please call instantiateDMP() first.";
        return unchangedState();
    }

    // A step needs the timestamps at tCurrent (start) and tCurrent + 1 (target).
    if (static_cast<std::size_t>(tCurrent) + 1 >= timestamps.size())
    {
        ARMARX_ERROR << "Unable to get next state, because time is gone.";
        // Rewind so that the next call starts again at the first timestamp.
        tCurrent = 0;
        return unchangedState();
    }

    const double tInit = timestamps[tCurrent];
    ++tCurrent;

    ARMARX_INFO << "In getNext State 3";
    const double t = timestamps[tCurrent];

    if (is3rdOrder)
    {
        currentDMP3rdOrder = calculateNextState(currentDMP3rdOrder, t, tInit, canonicalValues);
        return getcStateVec(currentDMP3rdOrder);
    }

    currentDMPState = calculateNextState(currentDMPState, t, tInit, canonicalValues);
    return getcStateVec(currentDMPState);
}


/**
 * Removes the DMP entity named dmpName from the DMP memory segment.
 *
 * Logs an error and returns if the segment proxy is unset, the name is
 * unknown, or the entity is not a DMP entity. After removal it re-checks the
 * name and warns if an entity with that name is still present.
 */
void DMPComponent::removeDMPFromDatabase(const std::string& dmpName, const ::Ice::Current&)
{
    ARMARX_INFO << "removing DMP from database";

    if (!dmpDataMemoryPrx)
    {
        ARMARX_ERROR << "cannot remove DMP: dmp segment of longterm memory is not available";
        return;
    }

    if (!dmpDataMemoryPrx->hasEntityByName(dmpName))
    {
        ARMARX_ERROR << "DMP with name " + dmpName + " does not exist in the database";
        return;
    }

    memoryx::DMPEntityPtr dmpEntity = memoryx::DMPEntityPtr::dynamicCast(dmpDataMemoryPrx->getDMPEntityByName(dmpName));

    if (!dmpEntity)
    {
        ARMARX_ERROR << "Entity with name " + dmpName + " is not a DMP entity";
        return;
    }

    dmpDataMemoryPrx->removeEntity(dmpEntity->getId());

    if (!dmpDataMemoryPrx->hasEntityByName(dmpName))
    {
        ARMARX_INFO << "successfully removed dmp from database";
    }
    else
    {
        ARMARX_WARNING << "DMP with name " + dmpName + " is still present in the database after removal";
    }
}

/**
 * Replaces the list of trajectory dimensions used when loading ".vsg" data.
 *
 * @param value dimension indices, transported as doubles; each is truncated
 *        to an int index.
 */
void DMPComponent::setDimensionsToLearn(const DVector& value, const Ice::Current&)
{
    usedDimensions.clear();

    // Range-for avoids the signed/unsigned comparison of an int index with size().
    for (const auto& dimension : value)
    {
        usedDimensions.push_back(static_cast<int>(dimension));
    }
}


//PropertyDefinitionsPtr DMPComponent::createPropertyDefinitions()
//{
//    return PropertyDefinitionsPtr(new DMPComponentPropertyDefinitions(
//                                      getConfigIdentifier()));
//}