From b6e004cdc0dcf0e3310df0d188fb55e0b06168d7 Mon Sep 17 00:00:00 2001
From: phesch <ulila@student.kit.edu>
Date: Mon, 27 Jun 2022 20:52:03 +0200
Subject: [PATCH] Add prediction dispatch to server memory defs

---
 .../server/ObjectMemory/ObjectMemory.cpp      |   8 +-
 .../libraries/armem/core/container_maps.h     |  61 ++++++
 .../armem/server/MemoryToIceAdapter.cpp       |   9 +
 .../armem/server/MemoryToIceAdapter.h         |   3 +
 .../server/plugins/ReadWritePluginUser.cpp    |   6 +-
 .../armem/server/wm/detail/Prediction.h       | 199 +++++++++++++++++-
 .../armem/server/wm/memory_definitions.cpp    | 105 ---------
 .../armem/server/wm/memory_definitions.h      |  12 +-
 8 files changed, 284 insertions(+), 119 deletions(-)

diff --git a/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp b/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp
index 291c3f229..dd568866a 100644
--- a/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp
+++ b/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp
@@ -237,14 +237,14 @@ namespace armarx::armem::server::obj
     }
 
     armem::prediction::data::PredictionResultSeq
-    ObjectMemory::predict(const armem::prediction::data::PredictionRequestSeq& iceRequests)
+    ObjectMemory::predict(const armem::prediction::data::PredictionRequestSeq& requests)
     {
-        auto requests = armarx::fromIce<std::vector<armem::PredictionRequest>>(iceRequests);
+        //auto requests = armarx::fromIce<std::vector<armem::PredictionRequest>>(iceRequests);
 
-        std::vector<armem::PredictionResult> results = workingMemory().predict(requests);
+        /*std::vector<armem::PredictionResult> results = workingMemory().predict(requests);
 
         auto iceResults = armarx::toIce<armem::prediction::data::PredictionResultSeq>(results);
-        return iceResults;
+        return iceResults;*/
 
         std::vector<armem::prediction::data::PredictionResult> results;
         for (const auto& request : requests)
diff --git a/source/RobotAPI/libraries/armem/core/container_maps.h b/source/RobotAPI/libraries/armem/core/container_maps.h
index 98c352e2f..803e9aa75 100644
--- a/source/RobotAPI/libraries/armem/core/container_maps.h
+++ b/source/RobotAPI/libraries/armem/core/container_maps.h
@@ -73,6 +73,50 @@ namespace armarx::armem
             return result;
         }
 
+        /**
+        * @brief Get the entry in the map for which the returned key is the longest prefix
+        *        of the given key among the keys in the map that satisfy the predicate.
+        *
+        * `prefixFunc` is used to successively calculate the prefixes of the given key.
+        * It must be pure and return an empty optional when there is no shorter
+        * prefix of the given key (for strings, this would be the case when passed the empty string).
+        * `predicate` is used to filter for entries that satisfy the desired condition.
+        * It must be pure.
+        *
+        * @param keyValMap the map that contains the key-value-pairs to search
+        * @param prefixFunc the function that returns the longest non-identical prefix of the key
+        * @param predicate the predicate to filter entries on
+        * @param key the key to calculate the prefixes of
+        *
+        * @return The iterator pointing to the found entry, or `keyValMap.end()`.
+        */
+        template <typename KeyT, typename ValueT>
+        typename std::map<KeyT, ValueT>::const_iterator
+        findEntryWithLongestPrefixAnd(
+            const std::map<KeyT, ValueT>& keyValMap,
+            const std::function<std::optional<KeyT>(KeyT&)>& prefixFunc,
+            const std::function<bool(const KeyT&, const ValueT&)>& predicate,
+            const KeyT& key)
+        {
+            std::optional<KeyT> curKey = key;
+
+            typename std::map<KeyT, ValueT>::const_iterator result = keyValMap.end();
+            do
+            {
+                auto iterator = keyValMap.find(curKey.value());
+                if (iterator != keyValMap.end() && predicate(iterator->first, iterator->second))
+                {
+                    result = iterator;
+                }
+                else
+                {
+                    curKey = prefixFunc(curKey.value());
+                }
+            }
+            while (result == keyValMap.end() and curKey.has_value());
+
+            return result;
+        }
 
         /**
         * @brief Accumulate all the values in a map for which the keys are prefixes of the given key.
@@ -194,6 +238,23 @@ namespace armarx::armem
     }
 
 
+    /**
+     * @brief Find the entry with the most specific key that contains the given ID
+     * and satisfies the predicate, or `idMap.end()` if no key contains the ID.
+     *
+     * @see `detail::findEntryWithLongestPrefixAnd()`
+     */
+    template <typename ValueT>
+    typename std::map<MemoryID, ValueT>::const_iterator
+    findMostSpecificEntryContainingIDAnd(const std::map<MemoryID, ValueT>& idMap,
+                                         const std::function<bool(const MemoryID&, const ValueT&)>& predicate,
+                                         const MemoryID& id)
+    {
+        return detail::findEntryWithLongestPrefixAnd<MemoryID, ValueT>(
+            idMap, &getMemoryIDParent, predicate, id);
+    }
+
+
     /**
      * @brief Return all values of keys containing the given ID.
      *
diff --git a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp
index bdd30e7c2..b9abf548b 100644
--- a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp
+++ b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp
@@ -360,6 +360,15 @@ namespace armarx::armem::server
     }
 
     // PREDICTION
+    prediction::data::PredictionResultSeq
+    MemoryToIceAdapter::predict(prediction::data::PredictionRequestSeq requests)
+    {
+        ARMARX_IMPORTANT << "Dispatching prediction requests.";
+        auto res = workingMemory->dispatchPredictions(
+            armarx::fromIce<std::vector<PredictionRequest>>(requests));
+        return armarx::toIce<prediction::data::PredictionResultSeq>(res);
+    }
+
     prediction::data::EngineSupportMap MemoryToIceAdapter::getAvailableEngines()
     {
         prediction::data::EngineSupportMap result;
diff --git a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h
index 96c1e822d..88e092c27 100644
--- a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h
+++ b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h
@@ -53,6 +53,9 @@ namespace armarx::armem::server
         data::StoreResult store(const armem::data::StoreInput& input);
 
         // PREDICTION
+        prediction::data::PredictionResultSeq
+        predict(prediction::data::PredictionRequestSeq requests);
+
         prediction::data::EngineSupportMap getAvailableEngines();
 
     public:
diff --git a/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp b/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp
index 21b214422..dd56d4c28 100644
--- a/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp
+++ b/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp
@@ -122,7 +122,9 @@ namespace armarx::armem::server::plugins
     armem::prediction::data::PredictionResultSeq
     ReadWritePluginUser::predict(const armem::prediction::data::PredictionRequestSeq& requests)
     {
-        armem::prediction::data::PredictionResultSeq result;
+        ARMARX_IMPORTANT << "Got prediction request.";
+        return iceAdapter().predict(requests);
+        /*armem::prediction::data::PredictionResultSeq result;
         for (const auto& request : requests)
         {
             armem::PredictionResult singleResult;
@@ -131,7 +133,7 @@ namespace armarx::armem::server::plugins
             singleResult.prediction = nullptr;
             result.push_back(singleResult.toIce());
         }
-        return result;
+        return result;*/
     }
 
     armem::prediction::data::EngineSupportMap
diff --git a/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h b/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h
index 22d44938d..6ec3f6289 100644
--- a/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h
+++ b/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h
@@ -22,13 +22,210 @@
 
 #pragma once
 
+#include <functional>
+#include "ArmarXCore/core/logging/Logging.h"
+
 #include <RobotAPI/libraries/armem/core/MemoryID.h>
 #include <RobotAPI/libraries/armem/core/Prediction.h>
+#include <RobotAPI/libraries/armem/core/base/detail/derived.h>
+#include <RobotAPI/libraries/armem/core/base/detail/lookup_mixins.h>
+#include <RobotAPI/libraries/armem/core/container_maps.h>
 
 
 namespace armarx::armem::server::wm::detail
 {
 
+    using Predictor = std::function<PredictionResult(const PredictionRequest&)>;
+
+    template <class DerivedT>
+    class Prediction
+    {
+    public:
+        explicit Prediction(const std::map<std::string, Predictor>& predictors = {})
+        {
+        }
+
+        const std::map<std::string, Predictor>&
+        predictors() const
+        {
+            return _predictors;
+        }
+
+        void
+        addPredictor(const std::string& engine, Predictor&& predictor)
+        {
+            _predictors.emplace(engine, predictor);
+        }
+
+        void
+        setPredictors(const std::map<std::string, Predictor>& predictors)
+        {
+            this->_predictors = predictors;
+        }
+
+        std::vector<PredictionResult>
+        dispatchPredictions(const std::vector<PredictionRequest>& requests)
+        {
+            const MemoryID ownID = base::detail::derived<DerivedT>(this).id();
+            std::vector<PredictionResult> results;
+            for (const auto& request : requests)
+            {
+                results.push_back(dispatchTargetedPrediction(request, ownID));
+            }
+            return results;
+        }
+
+        PredictionResult
+        dispatchTargetedPrediction(const PredictionRequest& request, const MemoryID& target)
+        {
+            PredictionResult result;
+            result.snapshotID = request.snapshotID;
+
+            MemoryID ownID = base::detail::derived<DerivedT>(this).id();
+            if (ownID == target)
+            {
+                auto pred = _predictors.find(request.predictionSettings.predictionEngineID);
+                if (pred != _predictors.end())
+                {
+                    ARMARX_IMPORTANT << "Dispatching to self: Running engine " << pred->first;
+                    return pred->second(request);
+                }
+
+                result.success = false;
+                std::stringstream sstream;
+                sstream << "Could not dispatch prediction request for " << request.snapshotID
+                        << " with engine '" << request.predictionSettings.predictionEngineID
+                        << "' in " << ownID << ": Engine not registered.";
+                result.errorMessage = sstream.str();
+            }
+            else
+            {
+                result.success = false;
+                std::stringstream sstream;
+                sstream << "Could not dispatch prediction request for " << request.snapshotID
+                        << " to " << target << " from " << ownID;
+                result.errorMessage = sstream.str();
+            }
+            return result;
+        }
+
+    private:
+        std::map<std::string, Predictor> _predictors; // NOLINT
+    };
+
+    template <class DerivedT>
+    class PredictionContainer : public Prediction<DerivedT>
+    {
+    public:
+        using Prediction<DerivedT>::Prediction;
+
+        std::vector<PredictionResult>
+        dispatchPredictions(const std::vector<PredictionRequest>& requests)
+        {
+            const auto& derivedThis = base::detail::derived<DerivedT>(this);
+            ARMARX_IMPORTANT << "Dispatching predictions at " << derivedThis.id();
+            const std::map<MemoryID, std::vector<PredictionEngine>> engines =
+                derivedThis.getAllPredictionEngines();
+
+            std::vector<PredictionResult> results;
+            for (const PredictionRequest& request : requests)
+            {
+                ARMARX_IMPORTANT << "Got request for " << request.snapshotID << " with engine "
+                          << request.predictionSettings.predictionEngineID;
+                PredictionResult result;
+                result.snapshotID = request.snapshotID;
+
+                auto iter =
+                    armem::findMostSpecificEntryContainingIDAnd<std::vector<PredictionEngine>>(
+                        engines,
+                        [&request](const MemoryID& /*unused*/,
+                                   const std::vector<PredictionEngine>& supported) -> bool
+                        {
+                            return std::find_if(
+                                       supported.begin(),
+                                       supported.end(),
+                                       [&request](const PredictionEngine& engine) {
+                                           return engine.engineID ==
+                                                  request.predictionSettings.predictionEngineID;
+                                       }) != supported.end();
+                        },
+                        request.snapshotID);
+
+                if (iter != engines.end())
+                {
+                    const MemoryID& responsibleID = iter->first;
+                    const std::vector<PredictionEngine>& supportedEngines = iter->second;
+
+                    ARMARX_IMPORTANT << "Found responsible memory item: " << responsibleID;
+                    result = dispatchTargetedPrediction(request, responsibleID);
+                    ARMARX_IMPORTANT << "Got result with " << result.success << ", "
+                                     << result.snapshotID << ", and " << result.prediction
+                                     << " with " << result.errorMessage;
+                }
+                else
+                {
+                    result.success = false;
+                    std::stringstream sstream;
+                    sstream << "Could not find segment offering prediction engine '"
+                            << request.predictionSettings.predictionEngineID << "' for memory ID "
+                            << request.snapshotID << ".";
+                    result.errorMessage = sstream.str();
+                }
+                results.push_back(result);
+            }
+            return results;
+        }
+
+        PredictionResult
+        dispatchTargetedPrediction(const PredictionRequest& request, const MemoryID& target)
+        {
+            PredictionResult result;
+            result.snapshotID = request.snapshotID;
 
+            const auto& derivedThis = base::detail::derived<DerivedT>(this);
+            MemoryID ownID = derivedThis.id();
+            if (ownID == target)
+            {
+                ARMARX_IMPORTANT << "Dispatching to self.";
+                result = Prediction<DerivedT>::dispatchTargetedPrediction(request, target);
+            }
+            else if (contains(ownID, target))
+            {
+                std::string childName = *(target.getItems().begin() +
+                                          static_cast<long>(ownID.getItems().size())); // NOLINT
+                // TODO(phesch): Looping over all the children just to find the one
+                //               with the right name isn't nice, but it's the interface we've got.
+                typename DerivedT::ChildT* child = nullptr;
+                derivedThis.forEachChild(
+                    [&child, &childName](auto& otherChild)
+                    {
+                        if (otherChild.name() == childName)
+                        {
+                            child = &otherChild;
+                        }
+                    });
+                if (child)
+                {
+                    result = child->dispatchTargetedPrediction(request, target);
+                }
+                else
+                {
+                    result.success = false;
+                    std::stringstream sstream;
+                    sstream << "Could not find memory item with ID " << target;
+                    result.errorMessage = sstream.str();
+                }
+            }
+            else
+            {
+                result.success = false;
+                std::stringstream sstream;
+                sstream << "Could not dispatch prediction request for " << request.snapshotID
+                        << " to " << target << " from " << ownID;
+                result.errorMessage = sstream.str();
+            }
+            return result;
+        }
+    };
 
-} // namespace armarx::armem::base::detail
+} // namespace armarx::armem::server::wm::detail
diff --git a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp
index a734501b1..973e3ca5c 100644
--- a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp
+++ b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp
@@ -109,109 +109,4 @@ namespace armarx::armem::server::wm
         result.memoryUpdateType = UpdateType::UpdatedExisting;
         return result;
     }
-
-    std::vector<PredictionResult> Memory::predict(const std::vector<PredictionRequest>& requests)
-    {
-        const std::map<MemoryID, std::vector<PredictionEngine>> engines = getAllPredictionEngines();
-
-        std::vector<PredictionResult> results;
-        for (const PredictionRequest& request : requests)
-        {
-            PredictionResult& result = results.emplace_back();
-            result.snapshotID = request.snapshotID;
-
-            /*
-             * Subproblem: Find entry in engines that ...
-             * - contains the snapshot ID,
-             * - supports the requested einge
-             * - and is most specific.
-             */
-
-            auto it = armem::findMostSpecificEntryContainingIDIf(
-                        engines, request.snapshotID,
-                        [&request](const std::vector<PredictionEngine>& supported)
-            {
-                return std::find(supported.begin(), supported.end(), request.predictionSettings.predictionEngineID) != supported.end();
-            });
-            if (it != engines.end())
-            {
-                const MemoryID& responsibleID = it->first;
-                const std::vector<PredictionEngine>& supportedEngines = it->second;
-
-                // TODO: Get container and let it do the prediction.
-
-                if (armem::contains(id().withCoreSegmentName("Instance"), request.snapshotID)
-                    and not request.snapshotID.hasGap()
-                    and request.snapshotID.hasTimestamp())
-                {
-                    objpose::ObjectPosePredictionRequest objPoseRequest;
-                    toIce(objPoseRequest.timeWindow, Duration::SecondsDouble(predictionTimeWindow));
-                    objPoseRequest.objectID = toIce(ObjectID(request.snapshotID.entityName));
-                    objPoseRequest.settings = request.settings;
-                    toIce(objPoseRequest.timestamp, request.snapshotID.timestamp);
-
-                    objpose::ObjectPosePredictionResult objPoseResult =
-                            predictObjectPoses({objPoseRequest}).at(0);
-                    result.success = objPoseResult.success;
-                    result.errorMessage = objPoseResult.errorMessage;
-
-                    if (objPoseResult.success)
-                    {
-                        armem::client::QueryBuilder builder;
-                        builder.latestEntitySnapshot(request.snapshotID);
-                        auto queryResult = armarx::fromIce<armem::client::QueryResult>(
-                                    query(builder.buildQueryInputIce()));
-                        std::string instanceError =
-                                "Could not find instance '" + request.snapshotID.str() + "' in memory";
-                        if (!queryResult.success)
-                        {
-                            result.success = false;
-                            result.errorMessage << instanceError << ":\n" << queryResult.errorMessage;
-                        }
-                        else
-                        {
-                            if (not request.snapshotID.hasInstanceIndex())
-                            {
-                                request.snapshotID.instanceIndex = 0;
-                            }
-                            auto* aronInstance = queryResult.memory.findLatestInstance(
-                                        request.snapshotID, request.snapshotID.instanceIndex);
-                            if (aronInstance == nullptr)
-                            {
-                                result.success = false;
-                                result.errorMessage << instanceError << ": No latest instance found.";
-                            }
-                            else
-                            {
-                                auto instance =
-                                        armem::arondto::ObjectInstance::FromAron(aronInstance->data());
-                                objpose::toAron(
-                                            instance.pose,
-                                            armarx::fromIce<objpose::ObjectPose>(objPoseResult.prediction));
-                                result.prediction = instance.toAron();
-                            }
-                        }
-                    }
-                }
-                else
-                {
-                    result.success = false;
-                    result.errorMessage << "No predictions are supported for MemoryID "
-                                        << request.snapshotID
-                                        << ". Have you given an instance index if requesting"
-                                        << " an object pose prediction?";
-                }
-            }
-            else
-            {
-                result.success = false;
-                result.errorMessage << "No predictions are supported for snapshot ID "
-                                    << request.snapshotID << ".";
-                result.prediction = nullptr;
-            }
-        }
-
-        return results;
-    }
-
 }
diff --git a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h
index bf06628e0..35015353a 100644
--- a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h
+++ b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h
@@ -61,6 +61,7 @@ namespace armarx::armem::server::wm
         public base::ProviderSegmentBase<Entity, ProviderSegment>
         , public detail::MaxHistorySizeParent<ProviderSegment>
         , public armem::wm::detail::FindInstanceDataMixin<ProviderSegment>
+        , public armem::server::wm::detail::Prediction<ProviderSegment>
     {
     public:
 
@@ -83,9 +84,10 @@ namespace armarx::armem::server::wm
 
     /// @brief base::CoreSegmentBase
     class CoreSegment :
-        public base::CoreSegmentBase<ProviderSegment, CoreSegment>,
-        public detail::MaxHistorySizeParent<CoreSegment>
+        public base::CoreSegmentBase<ProviderSegment, CoreSegment>
+        , public detail::MaxHistorySizeParent<CoreSegment>
         , public armem::wm::detail::FindInstanceDataMixin<CoreSegment>
+        , public armem::server::wm::detail::PredictionContainer<CoreSegment>
     {
         using Base = base::CoreSegmentBase<ProviderSegment, CoreSegment>;
 
@@ -126,6 +128,7 @@ namespace armarx::armem::server::wm
     class Memory :
         public base::MemoryBase<CoreSegment, Memory>
         , public armem::wm::detail::FindInstanceDataMixin<Memory>
+        , public armem::server::wm::detail::PredictionContainer<Memory>
     {
         using Base = base::MemoryBase<CoreSegment, Memory>;
 
@@ -147,11 +150,6 @@ namespace armarx::armem::server::wm
          */
         Base::UpdateResult updateLocking(const EntityUpdate& update);
 
-
-        std::vector<PredictionResult>
-        predict(const std::vector<PredictionRequest>& requests);
-
-
     };
 
 }
-- 
GitLab