Skip to content
Snippets Groups Projects
Commit c6816666 authored by Philip Scherer's avatar Philip Scherer
Browse files

Add linear predictions to Localization segment

parent 78eccea4
No related branches found
No related tags found
1 merge request!265Robot state predictions
......@@ -159,6 +159,7 @@ set(LIB_HEADERS
mns/plugins/Plugin.h
mns/plugins/PluginUser.h
util/prediction_helpers.h
util/util.h
)
......
/*
* This file is part of ArmarX.
*
* ArmarX is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation.
*
* ArmarX is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* @author phesch ( ulila at student dot kit dot edu )
* @date 2022
* @copyright http://www.gnu.org/licenses/gpl-2.0.txt
* GNU General Public License
*/
#pragma once
#include <functional>
#include <vector>
#include <ArmarXCore/core/time.h>
#include <RobotAPI/libraries/armem/core/MemoryID.h>
#include <RobotAPI/libraries/armem/server/wm/memory_definitions.h>
#include <RobotAPI/libraries/aron/core/data/variant/container/Dict.h>
namespace armarx::armem
{
template <typename DataType, typename LatestType>
struct LatestSnapshotInfo
{
bool success = false;
std::string errorMessage = "";
std::vector<double> timestampsSec = {};
std::vector<DataType> values = {};
LatestType latestValue;
};
template <typename SegmentType, typename DataType, typename LatestType>
LatestSnapshotInfo<DataType, LatestType>
getLatestSnapshots(const SegmentType* segment,
const MemoryID& entityID,
const DateTime& startTime,
const DateTime& endTime,
std::function<DataType(const aron::data::DictPtr&)> dictToData,
std::function<LatestType(const aron::data::DictPtr&)> dictToLatest)
{
LatestSnapshotInfo<DataType, LatestType> result;
result.success = false;
const server::wm::Entity* entity = segment->findEntity(entityID);
if (entity == nullptr)
{
std::stringstream sstream;
sstream << "Could not find entity with ID " << entityID << ".";
result.errorMessage = sstream.str();
return result;
}
const int instanceIndex = 0;
bool querySuccess = true;
entity->forEachSnapshotInTimeRange(
startTime,
endTime,
[&](const wm::EntitySnapshot& snapshot)
{
const auto* instance = snapshot.findInstance(instanceIndex);
if (instance)
{
result.timestampsSec.push_back(
(instance->id().timestamp - endTime).toSecondsDouble());
result.values.emplace_back(dictToData(instance->data()));
}
else
{
std::stringstream sstream;
sstream << "Could not find instance with index " << instanceIndex
<< " in snapshot " << snapshot.id() << ".";
result.errorMessage = sstream.str();
querySuccess = false;
}
});
if (querySuccess)
{
aron::data::DictPtr latest = entity->findLatestInstanceData(instanceIndex);
if (latest)
{
result.success = true;
result.latestValue = dictToLatest(latest);
}
else
{
std::stringstream sstream;
sstream << "Could not find instance with index " << instanceIndex << " for entity "
<< entity->id() << ".";
result.errorMessage = sstream.str();
}
}
return result;
}
} // namespace armarx::armem
......@@ -3,6 +3,9 @@
// STL
#include <iterator>
#include <SimoxUtility/math/pose/pose.h>
#include <SimoxUtility/math/regression/linear.h>
#include <ArmarXCore/core/logging/Logging.h>
#include <RobotAPI/libraries/core/FramedPose.h>
......@@ -12,6 +15,7 @@
#include <RobotAPI/libraries/armem/core/Time.h>
#include <RobotAPI/libraries/armem/core/aron_conversions.h>
#include <RobotAPI/libraries/armem/server/MemoryToIceAdapter.h>
#include <RobotAPI/libraries/armem/util/prediction_helpers.h>
#include <RobotAPI/libraries/armem_robot/aron/Robot.aron.generated.h>
#include <RobotAPI/libraries/armem_robot/robot_conversions.h>
......@@ -35,6 +39,24 @@ namespace armarx::armem::server::robot_state::localization
{
}
void Segment::defineProperties(armarx::PropertyDefinitionsPtr defs, const std::string& prefix)
{
Base::defineProperties(defs, prefix);
defs->optional(properties.predictionTimeWindow,
"prediction.TimeWindow",
"Duration of time window into the past to use for predictions"
" when requested via the PredictingMemoryInterface (in seconds).");
}
void Segment::init()
{
Base::init();
segmentPtr->addPredictor(armem::PredictionEngine{.engineID = "Linear"},
[this](const PredictionRequest& request){ return this->predictLinear(request); });
}
void Segment::onConnect()
{
......@@ -170,4 +192,70 @@ namespace armarx::armem::server::robot_state::localization
return update;
}
PredictionResult Segment::predictLinear(const PredictionRequest& request)
{
PredictionResult result;
result.snapshotID = request.snapshotID;
if (request.predictionSettings.predictionEngineID != "Linear")
{
result.success = false;
result.errorMessage = "Prediction engine " +
request.predictionSettings.predictionEngineID +
" is not supported in Proprioception.";
return result;
}
const DateTime timeOrigin = DateTime::Now();
const armarx::Duration timeWindow =
Duration::SecondsDouble(properties.predictionTimeWindow);
LatestSnapshotInfo<Eigen::Vector3d, arondto::Transform> info;
doLocked(
[&, this]()
{
info = getLatestSnapshots<server::wm::CoreSegment,
Eigen::Vector3d,
arondto::Transform>(
segmentPtr,
request.snapshotID,
timeOrigin - timeWindow,
timeOrigin,
[](const aron::data::DictPtr& data)
{ return simox::math::position(arondto::Transform::FromAron(data).transform).cast<double>(); },
[](const aron::data::DictPtr& data)
{ return arondto::Transform::FromAron(data); });
});
if (info.success)
{
Eigen::Vector3f latestPosition = simox::math::position(info.latestValue.transform);
Eigen::Vector3f prediction;
if (info.timestampsSec.size() <= 1)
{
prediction = latestPosition;
}
else
{
using simox::math::LinearRegression3d;
const bool inputOffset = false;
const LinearRegression3d model =
LinearRegression3d::Fit(info.timestampsSec, info.values, inputOffset);
const auto predictionTime = request.snapshotID.timestamp;
prediction = model.predict((predictionTime - timeOrigin).toSecondsDouble()).cast<float>();
}
simox::math::position(info.latestValue.transform) = prediction;
result.success = true;
result.prediction = info.latestValue.toAron();
}
else
{
result.success = false;
result.errorMessage = info.errorMessage;
}
return result;
}
} // namespace armarx::armem::server::robot_state::localization
......@@ -49,6 +49,9 @@ namespace armarx::armem::server::robot_state::localization
Segment(server::MemoryToIceAdapter& iceMemory);
virtual ~Segment() override;
void defineProperties(armarx::PropertyDefinitionsPtr defs, const std::string& prefix = "") override;
void init() override;
void onConnect();
......@@ -67,6 +70,13 @@ namespace armarx::armem::server::robot_state::localization
EntityUpdate makeUpdate(const armem::robot_state::Transform& transform) const;
PredictionResult predictLinear(const PredictionRequest& request);
struct Properties
{
double predictionTimeWindow = 2;
};
Properties properties;
};
......
......@@ -11,6 +11,7 @@
#include <RobotAPI/libraries/aron/core/data/variant/All.h>
#include <RobotAPI/libraries/armem/core/MemoryID.h>
#include <RobotAPI/libraries/armem/util/prediction_helpers.h>
#include <RobotAPI/libraries/armem_robot_state/aron/Proprioception.aron.generated.h>
......@@ -223,74 +224,29 @@ namespace armarx::armem::server::robot_state::proprioception
}
const DateTime timeOrigin = DateTime::Now();
const int instanceIndex = 0;
std::vector<double> timestampsSec;
std::vector<Eigen::VectorXd> jointValues;
aron::data::DictPtr latestData;
const armarx::Duration timeWindow = Duration::SecondsDouble(properties.predictionTimeWindow);
// Use result.success as a marker for whether to continue later
result.success = false;
LatestSnapshotInfo<Eigen::VectorXd, aron::data::DictPtr> info;
doLocked(
// Default capture because the number of variables was getting out of hand
[&, this]()
{
wm::Entity* entity = segmentPtr->findEntity(request.snapshotID);
if (entity == nullptr)
{
std::stringstream sstream;
sstream << "Could not find entity with ID " << request.snapshotID << ".";
result.errorMessage = sstream.str();
return;
}
bool querySuccess = true;
entity->forEachSnapshotInTimeRange(
timeOrigin - timeWindow, timeOrigin,
[&](const wm::EntitySnapshot& snapshot)
{
const auto* instance = snapshot.findInstance(instanceIndex);
if (instance)
{
timestampsSec.push_back(
(instance->id().timestamp - timeOrigin).toSecondsDouble());
jointValues.emplace_back(readJointData(*instance->data()));
}
else
{
std::stringstream sstream;
sstream << "Could not find instance with index " << instanceIndex
<< " in snapshot " << snapshot.id() << ".";
result.errorMessage = sstream.str();
querySuccess = false;
}
});
if (querySuccess)
{
latestData = entity->findLatestInstanceData(instanceIndex);
if (latestData)
{
result.success = true;
}
else
{
std::stringstream sstream;
sstream << "Could not find instance with index " << instanceIndex
<< " for entity " << entity->id() << ".";
result.errorMessage = sstream.str();
return;
}
}
info = getLatestSnapshots<server::wm::CoreSegment,
Eigen::VectorXd,
aron::data::DictPtr>(
segmentPtr,
request.snapshotID,
timeOrigin - timeWindow,
timeOrigin,
[](const aron::data::DictPtr& data) { return readJointData(*data); },
[](const aron::data::DictPtr& data) { return data; });
});
if (result.success)
if (info.success)
{
Eigen::VectorXd latestJoints = readJointData(*latestData);
Eigen::VectorXd latestJoints = readJointData(*info.latestValue);
Eigen::VectorXd prediction(latestJoints.size());
if (timestampsSec.size() <= 1)
if (info.timestampsSec.size() <= 1)
{
prediction = latestJoints;
}
......@@ -298,17 +254,23 @@ namespace armarx::armem::server::robot_state::proprioception
{
using simox::math::LinearRegression;
const bool inputOffset = false;
const LinearRegression model =
LinearRegression<Eigen::Dynamic>::Fit(timestampsSec, jointValues, inputOffset);
const LinearRegression model = LinearRegression<Eigen::Dynamic>::Fit(
info.timestampsSec, info.values, inputOffset);
const auto predictionTime = request.snapshotID.timestamp;
prediction = model.predict((predictionTime - timeOrigin).toSecondsDouble());
}
arondto::Proprioception templateData = arondto::Proprioception::FromAron(latestData);
arondto::Proprioception templateData =
arondto::Proprioception::FromAron(info.latestValue);
emplaceJointData(prediction, templateData);
result.success = true;
result.prediction = templateData.toAron();
}
else
{
result.success = false;
result.errorMessage = info.errorMessage;
}
return result;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment