Skip to content
Snippets Groups Projects
Commit 89b98026 authored by You Zhou's avatar You Zhou
Browse files

fixed the joint space DMP for control

parent 28d38e1c
No related branches found
No related tags found
No related merge requests found
......@@ -30,9 +30,7 @@ module armarx
class NJointJointSpaceDMPControllerConfig extends NJointControllerConfig
{
Ice::StringSeq jointNames;
float DMPKd = 20;
int kernelSize = 100;
double tau = 1;
int baseMode = 1;
double phaseL = 10;
......
......@@ -16,12 +16,15 @@ namespace armarx
NJointJSDMPController::NJointJSDMPController(const RobotUnitPtr&, const armarx::NJointControllerConfigPtr& config, const VirtualRobot::RobotPtr&)
{
ARMARX_INFO << "creating joint space dmp controller ... ";
useSynchronizedRtRobot();
cfg = NJointJointSpaceDMPControllerConfigPtr::dynamicCast(config);
ARMARX_CHECK_EXPRESSION_W_HINT(cfg, "Needed type: NJointJointSpaceDMPControllerConfigPtr");
for (std::string jointName : cfg->jointNames)
{
ControlTargetBase* ct = useControlTarget(jointName, ControlModes::VelocityTorque);
ControlTargetBase* ct = useControlTarget(jointName, ControlModes::Velocity1DoF);
const SensorValueBase* sv = useSensorValue(jointName);
targets.insert(std::make_pair(jointName, ct->asA<ControlTarget1DoFActuatorVelocity>()));
positionSensors.insert(std::make_pair(jointName, sv->asA<SensorValue1DoFActuatorPosition>()));
......@@ -31,8 +34,9 @@ namespace armarx
{
ARMARX_ERROR << "cfg->jointNames.size() == 0";
}
ARMARX_INFO << "start creating dmpPtr ... ";
dmpPtr.reset(new DMP::UMIDMP(cfg->kernelSize, cfg->DMPKd, cfg->baseMode, 1));
dmpPtr.reset(new DMP::UMIDMP(cfg->kernelSize, 100, cfg->baseMode, 1));
timeDuration = cfg->timeDuration;
phaseL = cfg->phaseL;
phaseK = cfg->phaseK;
......@@ -40,6 +44,7 @@ namespace armarx
phaseDist1 = cfg->phaseDist1;
phaseKp = cfg->phaseKp;
dimNames = cfg->jointNames;
ARMARX_INFO << "created dmpPtr ... ";
targetVels.resize(cfg->jointNames.size());
NJointJSDMPControllerControlData initData;
......@@ -47,6 +52,7 @@ namespace armarx
for (size_t i = 0; i < cfg->jointNames.size(); ++i)
{
initData.targetJointVels[i] = 0;
targetVels[i] = 0;
}
reinitTripleBuffer(initData);
......@@ -59,11 +65,14 @@ namespace armarx
controllerSensorData.reinitAllBuffers(initSensorData);
deltaT = 0;
qpos.resize(dimNames.size());
qvel.resize(dimNames.size());
}
void NJointJSDMPController::controllerRun()
{
if (!started)
if (!started || finished)
{
for (size_t i = 0; i < dimNames.size(); ++i)
{
......@@ -74,6 +83,7 @@ namespace armarx
{
currentState = controllerSensorData.getUpToDateReadBuffer().currentState;
double deltaT = controllerSensorData.getUpToDateReadBuffer().deltaT;
if (canVal > 1e-8)
{
double phaseStop = 0;
......@@ -127,8 +137,9 @@ namespace armarx
{
double vel0 = tau * currentState[i].vel / timeDuration;
double vel1 = phaseKp * (targetState[i] - currentPosition[i]);
double vel = mpcFactor * vel0 + (1 - mpcFactor) * vel1;
targetVels[i] = finished ? 0.0f : vel;
// double vel = mpcFactor * vel0 + (1 - mpcFactor) * vel1;
double vel = vel1 + vel0;
targetVels[i] = vel;
debugOutputData.getWriteBuffer().latestTargetVelocities[dimNames[i]] = vel;
}
......@@ -161,16 +172,25 @@ namespace armarx
DMP::DMPState currentPos;
currentPos.pos = (positionSensors.count(jointName) == 1) ? positionSensors[jointName]->position : 0.0f;
currentPos.vel = (velocitySensors.count(jointName) == 1) ? velocitySensors[jointName]->velocity : 0.0f;
qpos[i] = currentPos.pos;
qvel[i] = currentPos.vel;
controllerSensorData.getWriteBuffer().currentState[i] = currentPos;
}
controllerSensorData.getWriteBuffer().deltaT = timeSinceLastIteration.toSecondsDouble();
controllerSensorData.getWriteBuffer().currentTime += timeSinceLastIteration.toSecondsDouble();
controllerSensorData.commitWrite();
std::vector<double> targetJointVels = rtGetControlStruct().targetJointVels;
rt2UserData.getWriteBuffer().qpos = qpos;
rt2UserData.getWriteBuffer().qvel = qvel;
rt2UserData.commitWrite();
Eigen::VectorXf targetJointVels = rtGetControlStruct().targetJointVels;
// ARMARX_INFO << targetJointVels;
for (size_t i = 0; i < dimNames.size(); ++i)
{
if (fabs(targetJointVels[i]) > cfg->maxJointVel)
{
targets[dimNames[i]]->velocity = targetJointVels[i] < 0 ? -cfg->maxJointVel : cfg->maxJointVel;
......@@ -179,6 +199,7 @@ namespace armarx
{
targets[dimNames[i]]->velocity = targetJointVels[i];
}
}
......@@ -206,13 +227,12 @@ namespace armarx
}
}
dmpPtr->learnFromTrajectories(trajs);
dmpPtr->setOneStepMPC(true);
dmpPtr->styleParas = dmpPtr->getStyleParasWithRatio(ratios);
ARMARX_INFO << "Learned DMP ... ";
}
void NJointJSDMPController::runDMP(const Ice::DoubleSeq& goals, double tau, const Ice::Current&)
void NJointJSDMPController::runDMP(const Ice::DoubleSeq& goals, double times, const Ice::Current&)
{
while (!rt2UserData.updateReadBuffer())
{
......@@ -223,6 +243,7 @@ namespace armarx
targetState.resize(dimNames.size());
currentState.clear();
currentState.resize(dimNames.size());
std::vector<double> goalVec = goals;
for (size_t i = 0; i < dimNames.size(); i++)
{
DMP::DMPState currentPos;
......@@ -231,15 +252,42 @@ namespace armarx
currentState[i] = currentPos;
targetState.push_back(currentPos.pos);
}
dmpPtr->prepareExecution(goals, currentState, 1, 1);
if (rtGetRobot()->getRobotNode(dimNames[i])->isLimitless())
{
double tjv = goalVec[i];
double cjv = currentPos.pos;
double diff = std::fmod(tjv - cjv, 2 * M_PI);
if (fabs(diff) > M_PI)
{
if (signbit(diff))
{
diff = - 2 * M_PI - diff;
}
else
{
diff = 2 * M_PI - diff;
}
tjv = cjv - diff;
}
else
{
tjv = cjv + diff;
}
goalVec[i] = tjv;
ARMARX_INFO << "dim name: " << dimNames[i] << " current state: qpos: " << currentPos.pos << " orig target: " << goals[i] << " current goal: " << tjv;
}
this->goals = goals;
canVal = timeDuration * tau;
}
dmpPtr->prepareExecution(goalVec, currentState, 1, 1);
canVal = timeDuration;
finished = false;
isDisturbance = false;
tau = times;
ARMARX_INFO << "run DMP";
started = true;
......@@ -300,8 +348,6 @@ namespace armarx
void NJointJSDMPController::onDisconnectNJointController()
{
controllerTask->stop();
ARMARX_INFO << "stopped ...";
}
......
......@@ -20,7 +20,7 @@ namespace armarx
class NJointJSDMPControllerControlData
{
public:
std::vector<double> targetJointVels;
Eigen::VectorXf targetJointVels;
};
/**
......@@ -49,7 +49,7 @@ namespace armarx
void learnDMPFromFiles(const Ice::StringSeq& fileNames, const Ice::Current&) override;
void setSpeed(double times, const Ice::Current&) override;
void runDMP(const Ice::DoubleSeq& goals, double tau, const Ice::Current&) override;
void runDMP(const Ice::DoubleSeq& goals, double times, const Ice::Current&) override;
void showMessages(const Ice::Current&) override;
......@@ -84,15 +84,14 @@ namespace armarx
};
TripleBuffer<RTToUserData> rt2UserData;
std::map<std::string, const SensorValue1DoFActuatorPosition*> positionSensors;
std::map<std::string, const SensorValue1DoFActuatorVelocity*> velocitySensors;
std::map<std::string, ControlTarget1DoFActuatorVelocity*> targets;
IceUtil::Time last;
std::vector<double> goals;
DMP::UMIDMPPtr dmpPtr;
bool DMPAsForwardControl;
double timeDuration;
DMP::Vec<DMP::DMPState> currentState;
double canVal;
......@@ -112,10 +111,12 @@ namespace armarx
bool started;
std::vector<std::string> dimNames;
DMP::DVec targetState;
std::vector<double> targetVels;
Eigen::VectorXf targetVels;
mutable MutexType controllerMutex;
PeriodicTask<NJointJSDMPController>::pointer_type controllerTask;
Eigen::VectorXf qpos;
Eigen::VectorXf qvel;
// ManagedIceObject interface
protected:
void controllerRun();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment