From 854e838b143869a64a7ed7b3d5a578ba37c0ad3f Mon Sep 17 00:00:00 2001
From: phesch <ulila@student.kit.edu>
Date: Wed, 18 May 2022 20:29:15 +0200
Subject: [PATCH] ObjectMemory Visu: Predictions with fewer locks

---
 .../server/instance/SegmentAdapter.cpp        | 23 +++++++++-
 .../armem_objects/server/instance/Visu.cpp    | 44 ++++++++++++++-----
 .../armem_objects/server/instance/Visu.h      |  5 ++-
 3 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/source/RobotAPI/libraries/armem_objects/server/instance/SegmentAdapter.cpp b/source/RobotAPI/libraries/armem_objects/server/instance/SegmentAdapter.cpp
index 010263ceb..11a90d59b 100644
--- a/source/RobotAPI/libraries/armem_objects/server/instance/SegmentAdapter.cpp
+++ b/source/RobotAPI/libraries/armem_objects/server/instance/SegmentAdapter.cpp
@@ -94,6 +94,27 @@ namespace armarx::armem::server::obj::instance
         robotHead.fetchDatafields();
 
         visu.arviz = arviz;
+
+        visu.getPoseHistory =
+            [this](const objpose::ObjectPoseSeq& objectPoses, const Duration& timeWindow)
+        {
+            std::vector<std::map<DateTime, objpose::ObjectPose>> poseHistories;
+            segment.doLocked(
+                [this, &objectPoses, &timeWindow, &poseHistories]()
+                {
+                    for (const auto& objectPose : objectPoses)
+                    {
+                        const wm::Entity* entity = segment.findObjectEntity(objectPose.objectID);
+                        if (entity != nullptr)
+                        {
+                            poseHistories.push_back(instance::Segment::getObjectPosesInRange(
+                                *entity, Time::Now() - timeWindow, Time::Invalid()));
+                        }
+                    }
+                });
+            return poseHistories;
+        };
+
         if (!visu.updateTask)
         {
             visu.updateTask = new SimpleRunningTask<>([this]()
@@ -103,8 +124,6 @@ namespace armarx::armem::server::obj::instance
             visu.updateTask->start();
         }
 
-        visu.predictor = [this](const objpose::ObjectPosePredictionRequest& request)
-        { return predictObjectPoses({request}).at(0); };
 
         segment.connect(arviz);
     }
diff --git a/source/RobotAPI/libraries/armem_objects/server/instance/Visu.cpp b/source/RobotAPI/libraries/armem_objects/server/instance/Visu.cpp
index 5f1a8c0ae..4f11adc31 100644
--- a/source/RobotAPI/libraries/armem_objects/server/instance/Visu.cpp
+++ b/source/RobotAPI/libraries/armem_objects/server/instance/Visu.cpp
@@ -9,6 +9,8 @@
 
 #include <RobotAPI/libraries/ArmarXObjects/ice_conversions.h>
 #include <RobotAPI/libraries/ArmarXObjects/ObjectFinder.h>
+#include <RobotAPI/libraries/ArmarXObjects/predictions.h>
+#include <RobotAPI/libraries/armem/client/Prediction.h>
 
 
 namespace armarx::armem::server::obj::instance
@@ -76,9 +78,14 @@ namespace armarx::armem::server::obj::instance
         const ObjectFinder& objectFinder) const
     {
         std::map<std::string, viz::Layer> stage;
-        for (const objpose::ObjectPose& objectPose : objectPoses)
+        auto poseHistories =
+            getPoseHistory(objectPoses, Duration::SecondsDouble(linearPredictionTimeWindowSeconds));
+        for (size_t i = 0; i < objectPoses.size(); ++i)
         {
-            visualizeObjectPose(getLayer(objectPose.providerName, stage), objectPose, objectFinder);
+            visualizeObjectPose(getLayer(objectPoses.at(i).providerName, stage),
+                                objectPoses.at(i),
+                                poseHistories.at(i),
+                                objectFinder);
         }
         return simox::alg::get_values(stage);
     }
@@ -89,9 +96,19 @@ namespace armarx::armem::server::obj::instance
         const ObjectFinder& objectFinder) const
     {
         std::map<std::string, viz::Layer> stage;
+        objpose::ObjectPoseSeq poses;
         for (const auto& [id, objectPose] : objectPoses)
         {
-            visualizeObjectPose(getLayer(objectPose.providerName, stage), objectPose, objectFinder);
+            poses.push_back(objectPose);
+        }
+        auto poseHistories =
+            getPoseHistory(poses, Duration::SecondsDouble(linearPredictionTimeWindowSeconds));
+        for (size_t i = 0; i < poses.size(); ++i)
+        {
+            visualizeObjectPose(getLayer(poses.at(i).providerName, stage),
+                                poses.at(i),
+                                poseHistories.at(i),
+                                objectFinder);
         }
         return simox::alg::get_values(stage);
     }
@@ -116,9 +133,11 @@ namespace armarx::armem::server::obj::instance
         const ObjectFinder& objectFinder) const
     {
         viz::Layer layer = arviz.layer(providerName);
-        for (const objpose::ObjectPose& objectPose : objectPoses)
+        auto poseHistories =
+            getPoseHistory(objectPoses, Duration::SecondsDouble(linearPredictionTimeWindowSeconds));
+        for (size_t i = 0; i < poseHistories.size(); ++i)
         {
-            visualizeObjectPose(layer, objectPose, objectFinder);
+            visualizeObjectPose(layer, objectPoses.at(i), poseHistories.at(i), objectFinder);
         }
         return layer;
     }
@@ -126,6 +145,7 @@ namespace armarx::armem::server::obj::instance
     void Visu::visualizeObjectPose(
         viz::Layer& layer,
         const objpose::ObjectPose& objectPose,
+        const std::map<DateTime, objpose::ObjectPose>& poseHistory,
         const ObjectFinder& objectFinder) const
     {
         const bool show = objectPose.confidence >= minConfidence;
@@ -248,13 +268,13 @@ namespace armarx::armem::server::obj::instance
         }
         if (showLinearPredictions)
         {
-            objpose::ObjectPosePredictionRequest request;
-            toIce(request.objectID, id);
-            request.settings.predictionEngineID = "Linear Position Regression";
-            toIce(request.timeWindow, Duration::SecondsDouble(linearPredictionTimeWindowSeconds));
-            toIce(request.timestamp,
-                  Time::Now() + Duration::SecondsDouble(linearPredictionTimeOffsetSeconds));
-            auto predictionResult = predictor(request);
+            armem::client::PredictionSettings settings;
+            settings.predictionEngineID = "Linear Position Regression";
+            auto predictionResult = objpose::predictObjectPose(
+                poseHistory,
+                Time::Now() + Duration::SecondsDouble(linearPredictionTimeOffsetSeconds),
+                objectPose,
+                settings);
             if (predictionResult.success)
             {
                 auto predictedPose =
diff --git a/source/RobotAPI/libraries/armem_objects/server/instance/Visu.h b/source/RobotAPI/libraries/armem_objects/server/instance/Visu.h
index 56380159b..573aad790 100644
--- a/source/RobotAPI/libraries/armem_objects/server/instance/Visu.h
+++ b/source/RobotAPI/libraries/armem_objects/server/instance/Visu.h
@@ -53,6 +53,7 @@ namespace armarx::armem::server::obj::instance
         void visualizeObjectPose(
             viz::Layer& layer,
             const objpose::ObjectPose& objectPose,
+            const std::map<DateTime, objpose::ObjectPose>& poseHistory,
             const ObjectFinder& objectFinder
         ) const;
 
@@ -94,7 +95,9 @@ namespace armarx::armem::server::obj::instance
 
         SimpleRunningTask<>::pointer_type updateTask;
 
-        std::function<objpose::ObjectPosePredictionResult(const objpose::ObjectPosePredictionRequest&)> predictor;
+        std::function<std::vector<std::map<DateTime, objpose::ObjectPose>>
+                      (const objpose::ObjectPoseSeq&, const Duration& timeWindow)>
+            getPoseHistory;
 
 
         struct RemoteGui
-- 
GitLab