Skip to content
Snippets Groups Projects
Commit e6a1301a authored by Philip Scherer's avatar Philip Scherer
Browse files

Add linear predictions to Proprioception

parent d7f06d63
No related branches found
No related tags found
1 merge request!265Robot state predictions
......@@ -25,12 +25,13 @@
#include <RobotAPI/interface/core/PoseBase.h>
#include <RobotAPI/libraries/core/Pose.h>
#include <RobotAPI/libraries/armem/core/Prediction.h>
#include <RobotAPI/libraries/armem_robot_state/server/proprioception/aron_conversions.h>
#include <RobotAPI/libraries/armem_robot_state/server/common/Visu.h>
#include <RobotAPI/libraries/RobotAPIComponentPlugins/RobotUnitComponentPlugin.h>
#include <ArmarXCore/core/exceptions/local/ExpressionException.h>
#include <ArmarXCore/core/ice_conversions/ice_conversions_templates.h>
#include <ArmarXCore/libraries/ArmarXCoreComponentPlugins/DebugObserverComponentPlugin.h>
#include <ArmarXCore/core/logging/Logging.h>
......@@ -204,6 +205,31 @@ namespace armarx::armem::server::robot_state
return { new Pose(poseMap[robotName].matrix()) };
}
armem::prediction::data::PredictionResultSeq
RobotStateMemory::predict(const armem::prediction::data::PredictionRequestSeq& requests)
{
std::vector<armem::prediction::data::PredictionResult> results;
for (const auto& request : requests)
{
auto boRequest = armarx::fromIce<armem::PredictionRequest>(request);
armem::PredictionResult result;
if (armem::contains(workingMemory().id().withCoreSegmentName("Proprioception"),
boRequest.snapshotID) &&
!boRequest.snapshotID.hasGap() && boRequest.snapshotID.hasTimestamp())
{
result = proprioceptionSegment.predict(boRequest);
}
else
{
result.success = false;
result.errorMessage << "No predictions are supported for MemoryID "
<< boRequest.snapshotID;
}
results.push_back(result.toIce());
}
return results;
}
/*************************************************************/
// RobotUnit Streaming functions
......
......@@ -79,6 +79,11 @@ namespace armarx::armem::server::robot_state
// GlobalRobotPoseProvider interface
armarx::PoseBasePtr getGlobalRobotPose(Ice::Long timestamp, const std::string& robotName, const ::Ice::Current&) override;
using ReadWritePluginUser::predict;
armem::prediction::data::PredictionResultSeq
predict(const armem::prediction::data::PredictionRequestSeq& requests) override;
protected:
......
#include "Segment.h"
#include <SimoxUtility/math/regression/linear.hpp>
#include <ArmarXCore/core/application/properties/PropertyDefinitionContainer.h>
#include <ArmarXCore/core/exceptions/local/ExpressionException.h>
#include <ArmarXCore/core/logging/Logging.h>
......@@ -16,12 +18,26 @@ namespace armarx::armem::server::robot_state::proprioception
{
Segment::Segment(armem::server::MemoryToIceAdapter& memoryToIceAdapter) :
Base(memoryToIceAdapter, "Proprioception", arondto::Proprioception::ToAronType(), 1024)
Base(memoryToIceAdapter,
"Proprioception",
arondto::Proprioception::ToAronType(),
1024,
{{"Linear"}})
{
}
Segment::~Segment() = default;
void
Segment::defineProperties(armarx::PropertyDefinitionsPtr defs, const std::string& prefix)
{
Base::defineProperties(defs, prefix);
defs->optional(properties.predictionTimeWindow,
"prediction.TimeWindow",
"Duration of time window into the past to use for predictions"
" when requested via the PredictingMemoryInterface (in seconds).");
}
void Segment::onConnect(RobotUnitInterfacePrx robotUnitPrx)
{
......@@ -129,6 +145,165 @@ namespace armarx::armem::server::robot_state::proprioception
return robotUnitProviderID;
}
Eigen::VectorXd readJointData(const wm::EntityInstanceData& data)
{
namespace adn = aron::data;
std::vector<double> values;
auto addData =
[&](adn::DictPtr dict) // NOLINT
{
for (const auto& [name, value] : dict->getElements())
{
values.push_back(
static_cast<double>(adn::Float::DynamicCastAndCheck(value)->getValue()));
}
};
if (adn::DictPtr joints = getDictElement(data, "joints"))
{
if (adn::DictPtr jointsPosition = getDictElement(*joints, "position"))
{
addData(jointsPosition);
}
if (adn::DictPtr jointsVelocity = getDictElement(*joints, "velocity"))
{
addData(jointsVelocity);
}
if (adn::DictPtr jointsTorque = getDictElement(*joints, "torque"))
{
addData(jointsTorque);
}
}
Eigen::VectorXd vec =
Eigen::Map<Eigen::VectorXd>(values.data(), static_cast<Eigen::Index>(values.size()));
return vec;
}
void
emplaceJointData(const Eigen::VectorXd& jointData,
arondto::Proprioception& dataTemplate)
{
Eigen::Index row = 0;
for (auto& [joint, value] : dataTemplate.joints.position)
{
value = static_cast<float>(jointData(row++));
}
for (auto& [joint, value] : dataTemplate.joints.velocity)
{
value = static_cast<float>(jointData(row++));
}
for (auto& [joint, value] : dataTemplate.joints.torque)
{
value = static_cast<float>(jointData(row++));
}
}
armem::PredictionResult
Segment::predict(const armem::PredictionRequest& request) const
{
PredictionResult result;
result.snapshotID = request.snapshotID;
if (request.predictionSettings.predictionEngineID != "Linear")
{
result.success = false;
result.errorMessage = "Prediction engine " +
request.predictionSettings.predictionEngineID +
" is not supported in Proprioception.";
return result;
}
aron::data::DictPtr valueTemplate;
std::vector<double> timestampsSec;
std::vector<Eigen::VectorXd> jointValues;
aron::data::DictPtr latestData;
const DateTime timeOrigin = DateTime::Now();
const int instanceIndex = 0;
doLocked(
// Default capture because the number of variables was getting out of hand
[&, this]()
{
// Use result.success as a marker for whether to continue later
result.success = false;
const armarx::Duration timeWindow =
Duration::SecondsDouble(properties.predictionTimeWindow);
wm::Entity* entity = segmentPtr->findEntity(request.snapshotID);
if (entity == nullptr)
{
std::stringstream sstream;
sstream << "Could not find entity with ID " << request.snapshotID << ".";
result.errorMessage = sstream.str();
return;
}
bool querySuccess = true;
entity->forEachSnapshotInTimeRange(
Time::Now() - timeWindow,
Time::Invalid(),
[&](
const wm::EntitySnapshot& snapshot)
{
const auto* instance = snapshot.findInstance(instanceIndex);
if (instance == nullptr)
{
std::stringstream sstream;
sstream << "Could not find instance with index " << instanceIndex
<< " in snapshot " << snapshot.id() << ".";
result.errorMessage = sstream.str();
querySuccess = false;
return;
}
timestampsSec.push_back(
(instance->id().timestamp - timeOrigin).toSecondsDouble());
jointValues.emplace_back(readJointData(*instance->data()));
valueTemplate = instance->data();
});
if (!querySuccess)
{
return;
}
latestData = entity->findLatestInstanceData(instanceIndex);
if (latestData == nullptr)
{
std::stringstream sstream;
sstream << "Could not find instance with index " << instanceIndex
<< " for entity " << entity->id() << ".";
result.errorMessage = sstream.str();
return;
}
result.success = true;
});
if (!result.success)
{
return result;
}
Eigen::VectorXd latestJoints = readJointData(*latestData);
Eigen::VectorXd prediction(latestJoints.size());
if (timestampsSec.size() <= 1)
{
prediction = latestJoints;
}
else
{
using simox::math::LinearRegression;
const bool inputOffset = false;
const LinearRegression model =
LinearRegression<Eigen::Dynamic>::Fit(timestampsSec, jointValues, inputOffset);
const auto predictionTime = request.snapshotID.timestamp;
prediction = model.predict((predictionTime - timeOrigin).toSecondsDouble());
}
arondto::Proprioception templateData = arondto::Proprioception::FromAron(latestData);
emplaceJointData(prediction, templateData);
result.success = true;
result.prediction = templateData.toAron();
return result;
}
std::map<std::string, float>
Segment::readJointPositions(const wm::EntityInstanceData& data)
......@@ -151,4 +326,4 @@ namespace armarx::armem::server::robot_state::proprioception
return jointPositions;
}
} // namespace armarx::armem::server::robot_state::proprioception
} // namespace armarx::armem::server::robot_state::proprioception
......@@ -32,6 +32,7 @@
// RobotAPI
#include <RobotAPI/libraries/armem/core/MemoryID.h>
#include <RobotAPI/libraries/armem/core/Prediction.h>
#include <RobotAPI/libraries/armem/server/segment/SpecializedCoreSegment.h>
#include <RobotAPI/libraries/armem/server/segment/SpecializedProviderSegment.h>
#include <RobotAPI/libraries/armem_robot_state/server/forward_declarations.h>
......@@ -52,6 +53,8 @@ namespace armarx::armem::server::robot_state::proprioception
Segment(server::MemoryToIceAdapter& iceMemory);
virtual ~Segment() override;
void defineProperties(armarx::PropertyDefinitionsPtr defs, const std::string& prefix = "") override;
void onConnect(RobotUnitInterfacePrx robotUnitPrx);
......@@ -62,6 +65,8 @@ namespace armarx::armem::server::robot_state::proprioception
const armem::MemoryID& getRobotUnitProviderID() const;
armem::PredictionResult predict(const armem::PredictionRequest& request) const;
private:
......@@ -75,6 +80,12 @@ namespace armarx::armem::server::robot_state::proprioception
RobotUnitInterfacePrx robotUnit;
armem::MemoryID robotUnitProviderID;
struct Properties
{
double predictionTimeWindow = 2;
};
Properties properties;
// Debug Observer prefix
const std::string dp = "Proprioception::getRobotJointPositions() | ";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment