/*
 * This file is part of ArmarX.
 *
 * ArmarX is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 *
 * ArmarX is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 * @package    RobotAPI::ArmarXObjects::DMPComponent
 * @author     Mirko Waechter ( mirko dot waechter at kit dot edu )
 * @date       2015
 * @copyright  http://www.gnu.org/licenses/gpl-2.0.txt
 *             GNU General Public License
 */

#include "DMPComponent.h"

#include <algorithm>
#include <cctype>
#include <sstream>
#include <string>

#include <MMM/Motion/MotionReaderXML.h>
#include <dmp/io/MMMConverter.h>


using namespace armarx;


/**
 * @brief Component init hook: declares the dependency on the longterm memory.
 *
 * The proxy name comes from the "LongtermMemoryName" property and is
 * registered with usingProxy().
 */
void DMPComponent::onInitComponent()
{
    ARMARX_INFO << "initializing DMP component";

    const std::string memoryName = getProperty<std::string>("LongtermMemoryName").getValue();
    usingProxy(memoryName);

    ARMARX_INFO << "successfully initialized DMP component" ;
}


/**
 * @brief Component connect hook: resolves the longterm memory proxy and its DMP segment.
 *
 * Each failure is logged and aborts the connect step early; the
 * "successfully connected" message is only emitted when both lookups succeed.
 */
void DMPComponent::onConnectComponent()
{
    ARMARX_INFO << "connecting DMP component";

    // Step 1: the longterm memory proxy, looked up by the configured name.
    try
    {
        const std::string memoryName = getProperty<std::string>("LongtermMemoryName").getValue();
        longtermMemoryPrx = getProxy<memoryx::LongtermMemoryInterfacePrx>(memoryName);
    }
    catch (...)
    {
        ARMARX_ERROR << "cannot get longterm memory proxy";
        return;
    }

    // Step 2: the DMP segment hosted by that memory.
    try
    {
        dmpDataMemoryPrx = longtermMemoryPrx->getDMPSegment();
    }
    catch (...)
    {
        ARMARX_ERROR << "cannot get dmp segment of longterm memory";
        return;
    }

    ARMARX_INFO << "successfully connected DMP component";
}


/// @brief Component disconnect hook; only logs, no resources are released here.
void DMPComponent::onDisconnectComponent()
{
    const char* const message = "disconnecting DMP component";
    ARMARX_INFO << message;
}


/// @brief Component exit hook; only logs, no resources are released here.
void DMPComponent::onExitComponent()
{
    const char* const message = "exiting DMP component";
    ARMARX_INFO << message;
}


/**
 * @brief Creates a new DMP object of the requested type.
 *
 * Second-order types are stored in @c basicdmp, third-order types in
 * @c dmp3rdorder. @c is3rdOrder is updated accordingly, so that later calls
 * (storing, stepping) dereference the pointer that was actually filled.
 *
 * @param DMPType one of the ARMARX_DMPTYPE_* ids. On an unsupported id an
 *        error is logged and @c dmpType / @c is3rdOrder are left unchanged,
 *        so they keep describing the DMP that is actually instantiated.
 */
void DMPComponent::instantiateDMP(int DMPType, const Ice::Current &){
    ARMARX_INFO << "instantiate DMP";

    switch(DMPType){
    case ARMARX_DMPTYPE_BASICDMP:
        basicdmp.reset(new DMP::BasicDMP);
        is3rdOrder = false;
        break;
    case ARMARX_DMPTYPE_DMP3RDORDER:
        dmp3rdorder.reset(new DMP::DMP3rdOrder);
        is3rdOrder = true;
        break;
    case ARMARX_DMPTYPE_DMP3RDORDERFORCEFIELD:
        dmp3rdorder.reset(new DMP::DMP3rdOrderForceField);
        is3rdOrder = true;
        break;
    case ARMARX_DMPTYPE_ENDVELODMP:
        basicdmp.reset(new DMP::EndVeloDMP);
        is3rdOrder = false;
        break;
    case ARMARX_DMPTYPE_FORCEFIELDDMP:
        basicdmp.reset(new DMP::ForceFieldDMP);
        is3rdOrder = false;
        break;
    case ARMARX_DMPTYPE_SIMPLEENDVELODMP:
        basicdmp.reset(new DMP::SimpleEndVeloDMP);
        is3rdOrder = false;
        break;
//    case ARMARX_DMPTYPE_ENDVELFORCEFILELDDMP: basicdmp.reset(new DMP::EndVeloForceFieldDMP);break;
//    case ARMARX_DMPTYPE_QUATERNIONDMP: basicdmp.reset(new DMP::QuaternionDMP); break; //error: quanternionDMP is an abstract type.
//    case ARMARX_DMPTYPE_ADAPTIVEGOAL3RDORDERDMP: basicdmp.reset(new DMP::AdaptiveGoal3rdOrderDMP); break; //error:
//    case ARMARX_DMPTYPE_PERIODICTRANSIENTDMP: basicdmp.reset(new DMP::PeriodicTransientDMP); break;
    default:
        ARMARX_ERROR << "ERROR: It is not a valid dmp type. " ;
        // Nothing was instantiated: keep the previous type information.
        return;
    }

    dmpType = DMPType;
}


/**
 * @brief Serializes the currently instantiated DMP and adds it to the DMP memory segment.
 *
 * The entity stores the DMP type id, the name, the 3rd-order flag and the
 * boost text-archive serialization of the DMP object.
 *
 * @param name name under which the DMP entity is stored.
 *
 * Logs an error and stores nothing if the memory segment is not connected or
 * no DMP of the current order has been instantiated.
 */
void DMPComponent::storeDMPInDatabase(const std::string &name, const ::Ice::Current&){

    ARMARX_INFO << "storing DMP in the database";

    if(!dmpDataMemoryPrx){
        ARMARX_ERROR << "DMP memory segment is not available; cannot store DMP " + name;
        return;
    }
    // Dereferencing an empty DMP pointer below would be undefined behavior.
    if(is3rdOrder ? !dmp3rdorder : !basicdmp){
        ARMARX_ERROR << "No DMP has been instantiated; cannot store DMP " + name;
        return;
    }

    memoryx::DMPEntityPtr dmpEntity = new memoryx::DMPEntity(name);
    dmpEntity->setDMPType(dmpType);
    dmpEntity->setDMPName(name);
    dmpEntity->set3rdOrder(is3rdOrder);

    std::stringstream dmptext;
    {
        // Scope the archive so it is destroyed (and has written everything it
        // emits) before the stream content is read.
        boost::archive::text_oarchive ar(dmptext);
        if(is3rdOrder){
            ar << boost::serialization::make_nvp("dmp", *dmp3rdorder);
        }else{
            ar << boost::serialization::make_nvp("dmp", *basicdmp);
        }
    }
    dmpEntity->setDMPtextStr(dmptext.str());

    const std::string entityID = dmpDataMemoryPrx->addEntity(dmpEntity);
    dmpEntity->setId(entityID);

    ARMARX_INFO << "successfully stored DMP";

}


/**
 * @brief Loads the DMP entity @p dmpName from the memory segment and deserializes it.
 *
 * Sets @c dmpType and @c is3rdOrder from the entity. It then creates a fresh DMP
 * object of the stored type via instantiateDMP() and restores its state from the
 * stored boost text archive.
 *
 * Logs an error and returns early when the segment is not connected, the entity
 * does not exist or is not a DMP entity, or no DMP object could be instantiated
 * for the stored type.
 */
void DMPComponent::getDMPFromDatabase(const std::string &dmpName, const Ice::Current &){
    ARMARX_INFO << "getting DMP from database";
    if(!dmpDataMemoryPrx){
        ARMARX_ERROR << "DMP memory segment is not available; cannot load DMP " + dmpName;
        return;
    }
    if(!dmpDataMemoryPrx->hasEntityByName(dmpName)){
        ARMARX_ERROR << "DMP with name " + dmpName + " does not exist in the database";
        return;
    }

    memoryx::DMPEntityPtr dmpEntity = memoryx::DMPEntityPtr::dynamicCast(dmpDataMemoryPrx->getDMPEntityByName(dmpName));
    if(!dmpEntity){
        ARMARX_ERROR << "Entity with name " + dmpName + " is not a DMP entity";
        return;
    }

    std::string name = dmpEntity->getDMPName();
    ARMARX_INFO << "DMP with name: " + name + " is loaded";

    dmpType = dmpEntity->getDMPType();
    ARMARX_INFO << "dmp type is " + std::to_string(dmpType);

    // Deserialize into a freshly created DMP of the stored type: the current
    // pointers may be empty (nothing instantiated yet) or hold a DMP of a
    // different type than the stored one.
    instantiateDMP(dmpType, Ice::Current());

    is3rdOrder = dmpEntity->get3rdOrder();
    ARMARX_INFO << "is3rdOrder is " + std::to_string(is3rdOrder);

    if(is3rdOrder ? !dmp3rdorder : !basicdmp){
        ARMARX_ERROR << "No DMP object available to deserialize DMP " + dmpName + " into";
        return;
    }

    std::string textStr = dmpEntity->getDMPtextStr();
    ARMARX_INFO << textStr;
    std::istringstream istr(textStr);

    boost::archive::text_iarchive ar(istr);

    if(is3rdOrder)
        ar >> boost::serialization::make_nvp("dmp", *dmp3rdorder);
    else
        ar >> boost::serialization::make_nvp("dmp", *basicdmp);

    ARMARX_INFO << "successfully got dmp from database.";
}


/**
 * @brief Learns the instantiated DMP from all trajectories loaded so far.
 *
 * Trajectories are collected with readTrajectoryFromFile(). Only the basic
 * categories ARMARX_DMPTYPE_BASICDMP and ARMARX_DMPTYPE_DMP3RDORDER (as
 * reported by getBasicDMPType()) are trained. Any other category, a missing
 * DMP object, or an empty trajectory list is reported instead of being
 * ignored silently.
 */
void DMPComponent::trainDMP(const ::Ice::Current&){
    if(trajs.size() == 0){
        ARMARX_WARNING << "No trajectories loaded; call readTrajectoryFromFile() before training the DMP.";
        return;
    }

    const auto basicType = getBasicDMPType();
    if(basicType == ARMARX_DMPTYPE_BASICDMP){
        if(!basicdmp){
            ARMARX_ERROR << "No DMP has been instantiated; cannot train.";
            return;
        }
        basicdmp->learnFromTrajectories(trajs);
    }else if(basicType == ARMARX_DMPTYPE_DMP3RDORDER){
        if(!dmp3rdorder){
            ARMARX_ERROR << "No DMP has been instantiated; cannot train.";
            return;
        }
        dmp3rdorder->learnFromTrajectories(trajs);
    }else{
        ARMARX_ERROR << "Training is not supported for dmp type " + getDMPTypeName();
    }
}

void DMPComponent::calculateWholeTrajectory(){

}

/**
 * @brief Advances a (second-order) DMP from time @p tInit to @p t.
 *
 * Reads the goal (required) and temporal factor (optional, default 1.0) from
 * @c configs and delegates to basicdmp->calculateTrajectoryPoint().
 *
 * @param initialStates state at @p tInit.
 * @param canonicalValues canonical system values, passed through to the DMP.
 * @return the state at @p t. If the goal is not configured or no DMP is
 *         instantiated, an error is logged and @p initialStates is returned
 *         unchanged.
 */
DMP::Vec<DMP::DMPState> DMPComponent::calculateNextState(DMP::Vec<DMP::DMPState>& initialStates, double t, double tInit, DMP::DVec& canonicalValues){

    // Use find() rather than operator[]: operator[] would insert a default
    // value for the missing key, and boost::get on it would fail.
    configMap::iterator goalIt = configs.find(DMP_PARAMETERS_GOAL);
    if(goalIt == configs.end()){
        ARMARX_ERROR << "The goal of DMP must be specified";
        return initialStates;
    }
    if(!basicdmp){
        ARMARX_ERROR << "No DMP has been instantiated; cannot calculate the next state.";
        return initialStates;
    }

    DMP::DVec goal = boost::get<DMP::DVec>(goalIt->second);

    double temporalFactor = 1.0;
    configMap::iterator factorIt = configs.find(DMP_PARAMETERS_TEMPORALFACTOR);
    if(factorIt != configs.end())
        temporalFactor = boost::get<double>(factorIt->second);

    return basicdmp->calculateTrajectoryPoint(t, goal, tInit, initialStates, canonicalValues, temporalFactor);
}

/**
 * @brief Advances a third-order DMP from time @p tInit to @p t.
 *
 * Reads the goal (required) and temporal factor (optional, default 1.0) from
 * @c configs and delegates to dmp3rdorder->calculateTrajectoryPoint().
 *
 * @param initialStates state at @p tInit.
 * @param canonicalValues canonical system values, passed through to the DMP.
 * @return the state at @p t. If the goal is not configured or no DMP is
 *         instantiated, an error is logged and @p initialStates is returned
 *         unchanged.
 */
DMP::Vec<DMP::_3rdOrderDMP> DMPComponent::calculateNextState(DMP::Vec<DMP::_3rdOrderDMP> &initialStates, double t, double tInit, DMP::DVec& canonicalValues){
    // Use find() rather than operator[]: operator[] would insert a default
    // value for the missing key, and boost::get on it would fail.
    configMap::iterator goalIt = configs.find(DMP_PARAMETERS_GOAL);
    if(goalIt == configs.end()){
        ARMARX_ERROR << "The goal of DMP must be specified";
        return initialStates;
    }
    if(!dmp3rdorder){
        ARMARX_ERROR << "No DMP has been instantiated; cannot calculate the next state.";
        return initialStates;
    }

    DMP::DVec goal = boost::get<DMP::DVec>(goalIt->second);

    double temporalFactor = 1.0;
    configMap::iterator factorIt = configs.find(DMP_PARAMETERS_TEMPORALFACTOR);
    if(factorIt != configs.end())
        temporalFactor = boost::get<double>(factorIt->second);

    return dmp3rdorder->calculateTrajectoryPoint(t, goal, tInit, initialStates, canonicalValues, temporalFactor);
}


/**
 * @brief Loads a training trajectory from @p file and appends it to @c trajs.
 *
 * The format is chosen by the (case-insensitive) file extension:
 *  - ".xml": an MMM motion, converted with DMP::MMMConverter::fromMMMJoints();
 *  - ".csv": read with DMP::SampledTrajectoryV2::readFromCSVFile().
 * Any other extension, or an MMM file that cannot be loaded, is logged as an
 * error and nothing is appended.
 */
void DMPComponent::readTrajectoryFromFile(const std::string &file, const Ice::Current &){

    // Only look for the extension in the last path component, so a dot in a
    // directory name (e.g. "./data/traj") is not mistaken for one.
    const std::string::size_type slashPos = file.find_last_of('/');
    const std::string::size_type dotPos = file.rfind('.');
    std::string ext;
    if(dotPos != std::string::npos && (slashPos == std::string::npos || dotPos > slashPos))
        ext = file.substr(dotPos + 1);
    std::transform(ext.begin(), ext.end(), ext.begin(),
                   [](unsigned char c){ return static_cast<char>(std::tolower(c)); });

    if(ext == "xml"){
        MMM::MotionReaderXML motionreader;
        MMM::MotionPtr motion = motionreader.loadMotion(file);
        if(!motion){
            ARMARX_ERROR << "Error: could not load MMM motion from " + file;
            return;
        }
        DMP::SampledTrajectoryV2 traj;
        traj = DMP::MMMConverter::fromMMMJoints(motion);
        trajs.push_back(traj);
    }else if(ext == "csv"){
        DMP::SampledTrajectoryV2 traj;
        traj.readFromCSVFile(file);
        trajs.push_back(traj);
    }else{
        ARMARX_ERROR << "Error: The file is not valid (expected a .xml or .csv file): " + file;
        return;
    }

}


/**
 * @brief Hook for pushing @c configs into the DMP object; currently a no-op.
 *
 * calculateNextState() reads the goal and temporal factor straight from
 * @c configs, so nothing has to be forwarded here at the moment. The disabled
 * forwarding loop is kept below for reference.
 */
void DMPComponent::setDMPConfiguration(){
    // Intentionally empty.

//    for(configMap::iterator it = configs.begin(); it != configs.end(); it++){
//        if(getBasicDMPType() == ARMARX_DMPTYPE_BASICDMP){
//            basicdmp->setConfiguration(it->first, it->second);
//        }else if(getBasicDMPType() == ARMARX_DMPTYPE_DMP3RDORDER){
//            dmp3rdorder->setConfiguration(it->first, it->second);
//        }
//    }
}

/**
 * @brief Inserts or overwrites the parameters pairwise given by @p paraIDs / @p paraVals.
 *
 * @param paraIDs parameter ids (keys of @c configs).
 * @param paraVals values, matched to @p paraIDs by position.
 * @return a copy of the updated @c configs.
 *
 * If the two lists differ in length, a warning is logged and only the first
 * min(size) pairs are applied. Previously, the extra ids indexed past the
 * end of @p paraVals.
 */
configMap DMPComponent::constructConfigMap(DMP::Vec<int> paraIDs, DMP::Vec<paraType> paraVals){
    const size_t count = std::min<size_t>(paraIDs.size(), paraVals.size());
    if(paraIDs.size() != paraVals.size()){
        ARMARX_WARNING << "ID list and value list have different sizes; only the first "
                          + std::to_string(count) + " pairs are used.";
    }

    // operator[] inserts missing keys and overwrites existing ones in one lookup.
    for(size_t i = 0; i < count; i++){
        configs[paraIDs[i]] = paraVals[i];
    }
    return configs;
}


/**
 * @brief Returns the human-readable name of the current @c dmpType.
 *
 * Unknown ids are logged as an error and reported as "unknownDMP".
 */
std::string DMPComponent::getDMPTypeName(){
    struct TypeName { int id; const char* name; };
    static const TypeName typeNames[] = {
        {ARMARX_DMPTYPE_BASICDMP,                "BasicDMP"},
        {ARMARX_DMPTYPE_ENDVELODMP,              "EndVeloDMP"},
        {ARMARX_DMPTYPE_SIMPLEENDVELODMP,        "SimpleEndVeloDMP"},
        {ARMARX_DMPTYPE_FORCEFIELDDMP,           "ForceFieldDMP"},
        {ARMARX_DMPTYPE_ENDVELFORCEFILELDDMP,    "EndVeloForceFieldDMP"},
        {ARMARX_DMPTYPE_PERIODICTRANSIENTDMP,    "PeriodicTransientDMP"},
        {ARMARX_DMPTYPE_DMP3RDORDER,             "3rdOrderDMP"},
        {ARMARX_DMPTYPE_DMP3RDORDERFORCEFIELD,   "3rdOrderForceFieldDMP"},
        {ARMARX_DMPTYPE_ADAPTIVEGOAL3RDORDERDMP, "AdaptiveGoal3rdOrderDMP"},
        {ARMARX_DMPTYPE_QUATERNIONDMP,           "QuaternionDMP"},
    };

    for(const TypeName& entry : typeNames){
        if(entry.id == dmpType)
            return entry.name;
    }

    ARMARX_ERROR << "This is not a valid dmp type";
    return "unknownDMP";
}


/**
 * @brief Overwrites the current DMP state with @p state.
 *
 * For third-order DMPs, position, velocity and acceleration are copied into
 * @c currentDMP3rdOrder. Otherwise only position and velocity are copied into
 * @c currentDMPState (acceleration is ignored).
 */
void DMPComponent::setDMPState(const ::armarx::cStateVec &state, const ::Ice::Current& ){

    const size_t count = state.size();

    if(is3rdOrder){
        currentDMP3rdOrder.resize(count);
        for(size_t idx = 0; idx < count; ++idx){
            currentDMP3rdOrder[idx].pos = state[idx].pos;
            currentDMP3rdOrder[idx].vel = state[idx].vel;
            currentDMP3rdOrder[idx].acc = state[idx].acc;
        }
        return;
    }

    currentDMPState.resize(count);
    for(size_t idx = 0; idx < count; ++idx){
        currentDMPState[idx].pos = state[idx].pos;
        currentDMPState[idx].vel = state[idx].vel;
    }

}

/// @brief Sets (inserting or overwriting) the scalar parameter @p paraID in @c configs.
void DMPComponent::setParameter(const int paraID, double value, const ::Ice::Current&){
    // operator[] inserts a missing key or returns the existing slot, so one
    // assignment replaces the former find()/insert()/operator[] sequence.
    configs[paraID] = value;
}

/// @brief Sets (inserting or overwriting) the DMP goal in @c configs.
void DMPComponent::setGoal(const DVector &value, const Ice::Current &){
    configs[DMP_PARAMETERS_GOAL] = DMP::DVec(value);
}

/// @brief Sets (inserting or overwriting) the DMP start position in @c configs.
void DMPComponent::setStartPosition(const DVector &value, const Ice::Current &){
    configs[DMP_PARAMETERS_STARTPOSITION] = DMP::DVec(value);
}

/// @brief Replaces the timestamp sequence stepped through by getNextState().
void DMPComponent::setTimeStamps(const DVector &value, const Ice::Current &){
    timestamps = DMP::DVec(value);
}

/// @brief Replaces the canonical values passed to the DMP in getNextState().
void DMPComponent::setCanonicalValues(const DVector &value, const Ice::Current &){
    canonicalValues = DMP::DVec(value);
}


/// @brief Converts second-order DMP states to the Ice state vector (pos/vel only).
::armarx::cStateVec getcStateVec(const DMP::Vec<DMP::DMPState> &dmpstate){
    const size_t count = dmpstate.size();
    ::armarx::cStateVec result;
    result.resize(count);
    for(size_t idx = 0; idx < count; ++idx){
        result[idx].pos = dmpstate[idx].pos;
        result[idx].vel = dmpstate[idx].vel;
    }
    return result;
}

/// @brief Converts third-order DMP states to the Ice state vector (pos/vel/acc).
::armarx::cStateVec getcStateVec(const DMP::Vec<DMP::_3rdOrderDMP> &dmpstate){
    const size_t count = dmpstate.size();
    ::armarx::cStateVec result;
    result.resize(count);
    for(size_t idx = 0; idx < count; ++idx){
        result[idx].pos = dmpstate[idx].pos;
        result[idx].vel = dmpstate[idx].vel;
        result[idx].acc = dmpstate[idx].acc;
    }
    return result;
}

/**
 * @brief Advances the DMP by one timestamp step and returns the new state.
 *
 * A step goes from timestamps[tCurrent] (tInit) to timestamps[tCurrent + 1] (t)
 * and then increments @c tCurrent. If no canonical values are set, a single
 * 1.0 is used.
 *
 * Returns the unchanged current state (after logging) when:
 *  - no timestamps are set;
 *  - the timestamps are exhausted (tCurrent is then rewound to 0);
 *  - no current state has been set with setDMPState().
 * Previously these cases indexed past the end of the timestamp list or
 * computed from an empty state, and the rewind was unreachable.
 */
::armarx::cStateVec DMPComponent::getNextState(const ::Ice::Current&){

    setDMPConfiguration();

    if(timestamps.size() == 0){
        ARMARX_ERROR << "Timestampes must be specified.";
        return is3rdOrder ? getcStateVec(currentDMP3rdOrder) : getcStateVec(currentDMPState);
    }

    if(canonicalValues.size() == 0){
        ARMARX_WARNING << "Canonical value is not specified. It will be set 1.0.";
        canonicalValues.push_back(1.0);
    }

    // A step needs two timestamps: the current one and the next one.
    const size_t current = static_cast<size_t>(tCurrent);
    if(current + 1 >= timestamps.size()){
        ARMARX_ERROR << "Unable to get next state, because time is gone.";
        // Rewind so that the next call starts from the first timestamp again.
        tCurrent = 0;
        return is3rdOrder ? getcStateVec(currentDMP3rdOrder) : getcStateVec(currentDMPState);
    }

    const double tInit = timestamps[current];
    const double t = timestamps[current + 1];

    if(!is3rdOrder){
        if(currentDMPState.size() == 0){
            ARMARX_ERROR << "The current state is not available. Please specify current state with setDMPState().";
            return getcStateVec(currentDMPState);
        }
        ++tCurrent;
        currentDMPState = calculateNextState(currentDMPState, t, tInit, canonicalValues);
        return getcStateVec(currentDMPState);
    }

    if(currentDMP3rdOrder.size() == 0){
        ARMARX_ERROR << "The current state is not available. Please specify current state with setDMPState().";
        return getcStateVec(currentDMP3rdOrder);
    }
    ++tCurrent;
    currentDMP3rdOrder = calculateNextState(currentDMP3rdOrder, t, tInit, canonicalValues);
    return getcStateVec(currentDMP3rdOrder);
}


/**
 * @brief Removes the DMP entity named @p dmpName from the DMP memory segment.
 *
 * Logs an error if the segment is not connected, the entity does not exist or
 * is not a DMP entity. Logs a warning if an entity with that name is still
 * present after removal (e.g. duplicates, or a failed removal).
 */
void DMPComponent::removeDMPFromDatabase(const std::string &dmpName, const ::Ice::Current&){
    ARMARX_INFO << "removing DMP from database";
    if(!dmpDataMemoryPrx){
        ARMARX_ERROR << "DMP memory segment is not available; cannot remove DMP " + dmpName;
        return;
    }
    if(!dmpDataMemoryPrx->hasEntityByName(dmpName)){
        ARMARX_ERROR << "DMP with name " + dmpName + " does not exist in the database";
        return;
    }

    memoryx::DMPEntityPtr dmpEntity = memoryx::DMPEntityPtr::dynamicCast(dmpDataMemoryPrx->getDMPEntityByName(dmpName));
    if(!dmpEntity){
        ARMARX_ERROR << "Entity with name " + dmpName + " is not a DMP entity";
        return;
    }
    dmpDataMemoryPrx->removeEntity(dmpEntity->getId());

    if(dmpDataMemoryPrx->hasEntityByName(dmpName)){
        ARMARX_WARNING << "DMP with name " + dmpName + " is still present in the database after removal";
        return;
    }
    ARMARX_INFO << "successfully removed dmp from database";
}


//PropertyDefinitionsPtr DMPComponent::createPropertyDefinitions()
//{
//    return PropertyDefinitionsPtr(new DMPComponentPropertyDefinitions(
//                                      getConfigIdentifier()));
//}