diff --git a/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp b/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp
index 291c3f2290d25e69ad233adcde772700393a0cb3..dd568866aa5e82fe6a4bd86e8629c4472510d6e8 100644
--- a/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp
+++ b/source/RobotAPI/components/armem/server/ObjectMemory/ObjectMemory.cpp
@@ -237,14 +237,14 @@ namespace armarx::armem::server::obj
     }

     armem::prediction::data::PredictionResultSeq
-    ObjectMemory::predict(const armem::prediction::data::PredictionRequestSeq& iceRequests)
+    ObjectMemory::predict(const armem::prediction::data::PredictionRequestSeq& requests)
     {
-        auto requests = armarx::fromIce<std::vector<armem::PredictionRequest>>(iceRequests);
+        //auto requests = armarx::fromIce<std::vector<armem::PredictionRequest>>(iceRequests);

-        std::vector<armem::PredictionResult> results = workingMemory().predict(requests);
+        /*std::vector<armem::PredictionResult> results = workingMemory().predict(requests);

         auto iceResults = armarx::toIce<armem::prediction::data::PredictionResultSeq>(results);

-        return iceResults;
+        return iceResults;*/

         std::vector<armem::prediction::data::PredictionResult> results;
         for (const auto& request : requests)
diff --git a/source/RobotAPI/libraries/armem/core/container_maps.h b/source/RobotAPI/libraries/armem/core/container_maps.h
index 98c352e2f4e8049c1ffde1d6388ce78c6f2cb533..803e9aa7536569a4bd0ba982a30ecb5f94547de9 100644
--- a/source/RobotAPI/libraries/armem/core/container_maps.h
+++ b/source/RobotAPI/libraries/armem/core/container_maps.h
@@ -73,6 +73,50 @@ namespace armarx::armem
         return result;
     }

+    /**
+     * @brief Get the entry in the map for which the returned key is the longest prefix
+     *        of the given key among the keys in the map that satisfy the predicate.
+     *
+     * `prefixFunc` is used to successively calculate the prefixes of the given key.
+     * It must be pure and return an empty optional when there is no shorter
+     * prefix of the given key (for strings, this would be the case when passed the empty string).
+     * `predicate` is used to filter for entries that satisfy the desired condition.
+     * It must be pure.
+     *
+     * @param keyValMap the map that contains the key-value-pairs to search
+     * @param prefixFunc the function that returns the longest non-identical prefix of the key
+     * @param predicate the predicate to filter entries on
+     * @param key the key to calculate the prefixes of
+     *
+     * @return The iterator pointing to the found entry, or `keyValMap.end()`.
+     */
+    template <typename KeyT, typename ValueT>
+    typename std::map<KeyT, ValueT>::const_iterator
+    findEntryWithLongestPrefixAnd(
+        const std::map<KeyT, ValueT>& keyValMap,
+        const std::function<std::optional<KeyT>(KeyT&)>& prefixFunc,
+        const std::function<bool(const KeyT&, const ValueT&)>& predicate,
+        const KeyT& key)
+    {
+        std::optional<KeyT> curKey = key;
+
+        typename std::map<KeyT, ValueT>::const_iterator result = keyValMap.end();
+        do
+        {
+            auto iterator = keyValMap.find(curKey.value());
+            if (iterator != keyValMap.end() && predicate(iterator->first, iterator->second))
+            {
+                result = iterator;
+            }
+            else
+            {
+                curKey = prefixFunc(curKey.value());
+            }
+        }
+        while (result == keyValMap.end() and curKey.has_value());
+
+        return result;
+    }
+
     /**
      * @brief Accumulate all the values in a map for which the keys are prefixes of the given key.
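
[Illustration — editor's annotation, not part of the patch] The prefixFunc/predicate contract described in the doc comment above can be exercised with plain std::string keys. This is a hedged sketch only: it assumes the helper is reachable as armarx::armem::detail::findEntryWithLongestPrefixAnd (as the call site in findMostSpecificEntryContainingIDAnd in the next hunk suggests); the map contents and the '/'-separated key scheme are made up for the example.

    #include <iostream>
    #include <map>
    #include <optional>
    #include <string>

    #include <RobotAPI/libraries/armem/core/container_maps.h>

    int main()
    {
        const std::map<std::string, int> values{{"a", 1}, {"a/b", 2}, {"a/b/c", 3}};

        // Longest strict prefix of a '/'-separated key; empty optional once the root is reached.
        const auto parent = [](std::string& key) -> std::optional<std::string>
        {
            const auto pos = key.rfind('/');
            if (pos == std::string::npos)
            {
                return std::nullopt;
            }
            return key.substr(0, pos);
        };

        // Only accept entries whose value is even.
        const auto isEven = [](const std::string& /*key*/, const int& value) { return value % 2 == 0; };

        // "a/b/c" is the longest prefix of "a/b/c/d" in the map, but its value (3) is odd,
        // so the search continues and stops at "a/b" (value 2).
        const auto it = armarx::armem::detail::findEntryWithLongestPrefixAnd<std::string, int>(
            values, parent, isEven, std::string{"a/b/c/d"});
        if (it != values.end())
        {
            std::cout << it->first << " -> " << it->second << "\n";  // prints "a/b -> 2"
        }
        return 0;
    }
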
@@ -194,6 +238,23 @@ namespace armarx::armem
     }

+    /**
+     * @brief Find the entry with the most specific key that contains the given ID
+     *        and satisfies the predicate, or `idMap.end()` if no key contains the ID.
+     *
+     * @see `detail::findEntryWithLongestPrefixAnd()`
+     */
+    template <typename ValueT>
+    typename std::map<MemoryID, ValueT>::const_iterator
+    findMostSpecificEntryContainingIDAnd(const std::map<MemoryID, ValueT>& idMap,
+                                         const std::function<bool(const MemoryID&, const ValueT&)>& predicate,
+                                         const MemoryID& id)
+    {
+        return detail::findEntryWithLongestPrefixAnd<MemoryID, ValueT>(
+            idMap, &getMemoryIDParent, predicate, id);
+    }
+
+
     /**
      * @brief Return all values of keys containing the given ID.
      *
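
[Illustration — editor's annotation, not part of the patch] A hedged sketch of the ID-level wrapper above, mirroring how PredictionContainer::dispatchPredictions (further down in this change) uses it with a map keyed by MemoryID. The MemoryID member names and the segment names are assumptions for the example, not taken from this patch.

    #include <map>
    #include <string>

    #include <RobotAPI/libraries/armem/core/MemoryID.h>
    #include <RobotAPI/libraries/armem/core/container_maps.h>

    using armarx::armem::MemoryID;

    std::string findOwner()
    {
        MemoryID core;
        core.memoryName = "Object";
        core.coreSegmentName = "Instance";

        MemoryID provider = core;
        provider.providerSegmentName = "SomeProvider";

        const std::map<MemoryID, std::string> owners{
            {core, "core segment"},
            {provider, "provider segment"},
        };

        MemoryID entity = provider;
        entity.entityName = "obj_0";

        // Every entry is acceptable here, so the entry with the most specific key
        // containing `entity` wins: the provider segment.
        const auto acceptAll = [](const MemoryID& /*id*/, const std::string& /*value*/) { return true; };

        const auto it = armarx::armem::findMostSpecificEntryContainingIDAnd<std::string>(
            owners, acceptAll, entity);
        if (it == owners.end())
        {
            return "none";
        }
        return it->second;
    }
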
diff --git a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp
index bdd30e7c2d010862be714ab0dc21704ef5ca8af1..b9abf548b711194f2aac1db5b797043e05ff3a3f 100644
--- a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp
+++ b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.cpp
@@ -360,6 +360,15 @@ namespace armarx::armem::server
     }

     // PREDICTION
+    prediction::data::PredictionResultSeq
+    MemoryToIceAdapter::predict(prediction::data::PredictionRequestSeq requests)
+    {
+        ARMARX_IMPORTANT << "Dispatching prediction requests.";
+        auto res = workingMemory->dispatchPredictions(
+            armarx::fromIce<std::vector<PredictionRequest>>(requests));
+        return armarx::toIce<prediction::data::PredictionResultSeq>(res);
+    }
+
     prediction::data::EngineSupportMap MemoryToIceAdapter::getAvailableEngines()
     {
         prediction::data::EngineSupportMap result;
diff --git a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h
index 96c1e822d6f4150c6e341bd5c0d234f57efec2b5..88e092c272e08388c196f6e5491c0cbadc363192 100644
--- a/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h
+++ b/source/RobotAPI/libraries/armem/server/MemoryToIceAdapter.h
@@ -53,6 +53,9 @@ namespace armarx::armem::server
         data::StoreResult store(const armem::data::StoreInput& input);

         // PREDICTION
+        prediction::data::PredictionResultSeq
+        predict(prediction::data::PredictionRequestSeq requests);
+
         prediction::data::EngineSupportMap getAvailableEngines();

     public:
diff --git a/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp b/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp
index 21b2144220e53bb8a281247c8cbb70161fbc7c1e..dd56d4c280c0213bc1b77d362f60d60476bbda66 100644
--- a/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp
+++ b/source/RobotAPI/libraries/armem/server/plugins/ReadWritePluginUser.cpp
@@ -122,7 +122,9 @@ namespace armarx::armem::server::plugins
     armem::prediction::data::PredictionResultSeq
     ReadWritePluginUser::predict(const armem::prediction::data::PredictionRequestSeq& requests)
     {
-        armem::prediction::data::PredictionResultSeq result;
+        ARMARX_IMPORTANT << "Got prediction request.";
+        return iceAdapter().predict(requests);
+        /*armem::prediction::data::PredictionResultSeq result;
         for (const auto& request : requests)
         {
             armem::PredictionResult singleResult;
@@ -131,7 +133,7 @@ namespace armarx::armem::server::plugins
             singleResult.prediction = nullptr;
             result.push_back(singleResult.toIce());
         }
-        return result;
+        return result;*/
     }

     armem::prediction::data::EngineSupportMap
diff --git a/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h b/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h
index 22d44938dbdc81c1e2f3a97f57387592616fe62d..6ec3f62898b0ebb7b45dcd5e77fb8fbefbceff84 100644
--- a/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h
+++ b/source/RobotAPI/libraries/armem/server/wm/detail/Prediction.h
@@ -22,13 +22,210 @@
 #pragma once

+#include <algorithm>
+#include <functional>
+#include <map>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "ArmarXCore/core/logging/Logging.h"
+
 #include <RobotAPI/libraries/armem/core/MemoryID.h>
 #include <RobotAPI/libraries/armem/core/Prediction.h>
+#include <RobotAPI/libraries/armem/core/base/detail/derived.h>
+#include <RobotAPI/libraries/armem/core/base/detail/lookup_mixins.h>
+#include <RobotAPI/libraries/armem/core/container_maps.h>


 namespace armarx::armem::server::wm::detail
 {
+    using Predictor = std::function<PredictionResult(const PredictionRequest&)>;
+
+    template <class DerivedT>
+    class Prediction
+    {
+    public:
+        explicit Prediction(const std::map<std::string, Predictor>& predictors = {}) :
+            _predictors(predictors)
+        {
+        }
+
+        const std::map<std::string, Predictor>&
+        predictors() const
+        {
+            return _predictors;
+        }
+
+        void
+        addPredictor(const std::string& engine, Predictor&& predictor)
+        {
+            _predictors.emplace(engine, predictor);
+        }
+
+        void
+        setPredictors(const std::map<std::string, Predictor>& predictors)
+        {
+            this->_predictors = predictors;
+        }
+
+        std::vector<PredictionResult>
+        dispatchPredictions(const std::vector<PredictionRequest>& requests)
+        {
+            const MemoryID ownID = base::detail::derived<DerivedT>(this).id();
+            std::vector<PredictionResult> results;
+            for (const auto& request : requests)
+            {
+                results.push_back(dispatchTargetedPrediction(request, ownID));
+            }
+            return results;
+        }
+
+        PredictionResult
+        dispatchTargetedPrediction(const PredictionRequest& request, const MemoryID& target)
+        {
+            PredictionResult result;
+            result.snapshotID = request.snapshotID;
+
+            MemoryID ownID = base::detail::derived<DerivedT>(this).id();
+            if (ownID == target)
+            {
+                auto pred = _predictors.find(request.predictionSettings.predictionEngineID);
+                if (pred != _predictors.end())
+                {
+                    ARMARX_IMPORTANT << "Dispatching to self: Running engine " << pred->first;
+                    return pred->second(request);
+                }
+
+                result.success = false;
+                std::stringstream sstream;
+                sstream << "Could not dispatch prediction request for " << request.snapshotID
+                        << " with engine '" << request.predictionSettings.predictionEngineID
+                        << "' in " << ownID << ": Engine not registered.";
+                result.errorMessage = sstream.str();
+            }
+            else
+            {
+                result.success = false;
+                std::stringstream sstream;
+                sstream << "Could not dispatch prediction request for " << request.snapshotID
+                        << " to " << target << " from " << ownID;
+                result.errorMessage = sstream.str();
+            }
+            return result;
+        }
+
+    private:
+        std::map<std::string, Predictor> _predictors;  // NOLINT
+    };
+
+    template <class DerivedT>
+    class PredictionContainer : public Prediction<DerivedT>
+    {
+    public:
+        using Prediction<DerivedT>::Prediction;
+
+        std::vector<PredictionResult>
+        dispatchPredictions(const std::vector<PredictionRequest>& requests)
+        {
+            const auto& derivedThis = base::detail::derived<DerivedT>(this);
+            ARMARX_IMPORTANT << "Dispatching predictions at " << derivedThis.id();
+            const std::map<MemoryID, std::vector<PredictionEngine>> engines =
+                derivedThis.getAllPredictionEngines();
+
+            std::vector<PredictionResult> results;
+            for (const PredictionRequest& request : requests)
+            {
+                ARMARX_IMPORTANT << "Got request for " << request.snapshotID << " with engine "
+                                 << request.predictionSettings.predictionEngineID;
+                PredictionResult result;
+                result.snapshotID = request.snapshotID;
+
+                auto iter =
+                    armem::findMostSpecificEntryContainingIDAnd<std::vector<PredictionEngine>>(
+                        engines,
+                        [&request](const MemoryID& /*unused*/,
+                                   const std::vector<PredictionEngine>& supported) -> bool
+                        {
+                            return std::find_if(
+                                       supported.begin(),
+                                       supported.end(),
+                                       [&request](const PredictionEngine& engine) {
+                                           return engine.engineID ==
+                                                  request.predictionSettings.predictionEngineID;
+                                       }) != supported.end();
+                        },
+                        request.snapshotID);
+
+                if (iter != engines.end())
+                {
+                    const MemoryID& responsibleID = iter->first;
+                    const std::vector<PredictionEngine>& supportedEngines = iter->second;
+
+                    ARMARX_IMPORTANT << "Found responsible memory item: " << responsibleID;
+                    result = dispatchTargetedPrediction(request, responsibleID);
+                    ARMARX_IMPORTANT << "Got result with " << result.success << ", "
+                                     << result.snapshotID << ", and " << result.prediction
+                                     << " with " << result.errorMessage;
+                }
+                else
+                {
+                    result.success = false;
+                    std::stringstream sstream;
+                    sstream << "Could not find segment offering prediction engine '"
+                            << request.predictionSettings.predictionEngineID << "' for memory ID "
+                            << request.snapshotID << ".";
+                    result.errorMessage = sstream.str();
+                }
+                results.push_back(result);
+            }
+            return results;
+        }
+
+        PredictionResult
+        dispatchTargetedPrediction(const PredictionRequest& request, const MemoryID& target)
+        {
+            PredictionResult result;
+            result.snapshotID = request.snapshotID;
+            auto& derivedThis = base::detail::derived<DerivedT>(this);
+            MemoryID ownID = derivedThis.id();
+            if (ownID == target)
+            {
+                ARMARX_IMPORTANT << "Dispatching to self.";
+                result = Prediction<DerivedT>::dispatchTargetedPrediction(request, target);
+            }
+            else if (contains(ownID, target))
+            {
+                std::string childName = *(target.getItems().begin() +
+                                          static_cast<long>(ownID.getItems().size()));  // NOLINT
+                // TODO(phesch): Looping over all the children just to find the one
+                // with the right name isn't nice, but it's the interface we've got.
+                typename DerivedT::ChildT* child = nullptr;
+                derivedThis.forEachChild(
+                    [&child, &childName](auto& otherChild)
+                    {
+                        if (otherChild.name() == childName)
+                        {
+                            child = &otherChild;
+                        }
+                    });
+                if (child)
+                {
+                    result = child->dispatchTargetedPrediction(request, target);
+                }
+                else
+                {
+                    result.success = false;
+                    std::stringstream sstream;
+                    sstream << "Could not find memory item with ID " << target;
+                    result.errorMessage = sstream.str();
+                }
+            }
+            else
+            {
+                result.success = false;
+                std::stringstream sstream;
+                sstream << "Could not dispatch prediction request for " << request.snapshotID
+                        << " to " << target << " from " << ownID;
+                result.errorMessage = sstream.str();
+            }
+            return result;
+        }
+    };
-} // namespace armarx::armem::base::detail
+} // namespace armarx::armem::server::wm::detail
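
[Illustration — editor's annotation, not part of the patch] A minimal sketch of the Prediction mixin used on its own. It assumes that base::detail::derived<DerivedT>() merely downcasts the mixin to its CRTP-derived class and that PredictionRequest/PredictionResult expose the fields used above (snapshotID, predictionSettings.predictionEngineID, success, errorMessage). The ToySegment type and the "constant" engine name are made up for the example.

    #include <RobotAPI/libraries/armem/core/MemoryID.h>
    #include <RobotAPI/libraries/armem/core/Prediction.h>
    #include <RobotAPI/libraries/armem/server/wm/detail/Prediction.h>

    namespace example
    {
        using namespace armarx::armem;

        // Stand-in for a provider segment: the mixin only needs an id() to dispatch to self.
        struct ToySegment : public server::wm::detail::Prediction<ToySegment>
        {
            MemoryID id() const
            {
                MemoryID segmentID;
                segmentID.memoryName = "Toy";
                return segmentID;
            }
        };

        inline PredictionResult runConstantEngine()
        {
            ToySegment segment;
            segment.addPredictor("constant",
                                 [](const PredictionRequest& request)
                                 {
                                     PredictionResult result;
                                     result.success = true;
                                     result.snapshotID = request.snapshotID;
                                     return result;
                                 });

            PredictionRequest request;
            request.snapshotID = segment.id();
            request.predictionSettings.predictionEngineID = "constant";

            // Dispatching to the segment's own ID runs the registered engine; an unknown
            // engine name would instead yield success == false and a descriptive errorMessage.
            return segment.dispatchTargetedPrediction(request, segment.id());
        }
    } // namespace example
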
diff --git a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp
index a734501b1608e63b794d72032e48a40145b917db..973e3ca5c21f1913f5abb84718789b4d4d27b2d2 100644
--- a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp
+++ b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.cpp
@@ -109,109 +109,4 @@ namespace armarx::armem::server::wm
         result.memoryUpdateType = UpdateType::UpdatedExisting;
         return result;
     }
-
-    std::vector<PredictionResult> Memory::predict(const std::vector<PredictionRequest>& requests)
-    {
-        const std::map<MemoryID, std::vector<PredictionEngine>> engines = getAllPredictionEngines();
-
-        std::vector<PredictionResult> results;
-        for (const PredictionRequest& request : requests)
-        {
-            PredictionResult& result = results.emplace_back();
-            result.snapshotID = request.snapshotID;
-
-            /*
-             * Subproblem: Find entry in engines that ...
-             * - contains the snapshot ID,
-             * - supports the requested einge
-             * - and is most specific.
-             */
-
-            auto it = armem::findMostSpecificEntryContainingIDIf(
-                engines, request.snapshotID,
-                [&request](const std::vector<PredictionEngine>& supported)
-                {
-                    return std::find(supported.begin(), supported.end(), request.predictionSettings.predictionEngineID) != supported.end();
-                });
-            if (it != engines.end())
-            {
-                const MemoryID& responsibleID = it->first;
-                const std::vector<PredictionEngine>& supportedEngines = it->second;
-
-                // TODO: Get container and let it do the prediction.
-
-                if (armem::contains(id().withCoreSegmentName("Instance"), request.snapshotID)
-                    and not request.snapshotID.hasGap()
-                    and request.snapshotID.hasTimestamp())
-                {
-                    objpose::ObjectPosePredictionRequest objPoseRequest;
-                    toIce(objPoseRequest.timeWindow, Duration::SecondsDouble(predictionTimeWindow));
-                    objPoseRequest.objectID = toIce(ObjectID(request.snapshotID.entityName));
-                    objPoseRequest.settings = request.settings;
-                    toIce(objPoseRequest.timestamp, request.snapshotID.timestamp);
-
-                    objpose::ObjectPosePredictionResult objPoseResult =
-                        predictObjectPoses({objPoseRequest}).at(0);
-                    result.success = objPoseResult.success;
-                    result.errorMessage = objPoseResult.errorMessage;
-
-                    if (objPoseResult.success)
-                    {
-                        armem::client::QueryBuilder builder;
-                        builder.latestEntitySnapshot(request.snapshotID);
-                        auto queryResult = armarx::fromIce<armem::client::QueryResult>(
-                            query(builder.buildQueryInputIce()));
-                        std::string instanceError =
-                            "Could not find instance '" + request.snapshotID.str() + "' in memory";
-                        if (!queryResult.success)
-                        {
-                            result.success = false;
-                            result.errorMessage << instanceError << ":\n" << queryResult.errorMessage;
-                        }
-                        else
-                        {
-                            if (not request.snapshotID.hasInstanceIndex())
-                            {
-                                request.snapshotID.instanceIndex = 0;
-                            }
-                            auto* aronInstance = queryResult.memory.findLatestInstance(
-                                request.snapshotID, request.snapshotID.instanceIndex);
-                            if (aronInstance == nullptr)
-                            {
-                                result.success = false;
-                                result.errorMessage << instanceError << ": No latest instance found.";
-                            }
-                            else
-                            {
-                                auto instance =
-                                    armem::arondto::ObjectInstance::FromAron(aronInstance->data());
-                                objpose::toAron(
-                                    instance.pose,
-                                    armarx::fromIce<objpose::ObjectPose>(objPoseResult.prediction));
-                                result.prediction = instance.toAron();
-                            }
-                        }
-                    }
-                }
-                else
-                {
-                    result.success = false;
-                    result.errorMessage << "No predictions are supported for MemoryID "
-                                        << request.snapshotID
-                                        << ". Have you given an instance index if requesting"
-                                        << " an object pose prediction?";
-                }
-            }
-            else
-            {
-                result.success = false;
-                result.errorMessage << "No predictions are supported for snapshot ID "
-                                    << request.snapshotID << ".";
-                result.prediction = nullptr;
-            }
-        }
-
-        return results;
-    }
-
 }
diff --git a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h
index bf06628e0fe0eab20813e1d2a454e78c0cfe457b..35015353a2653ca957bb19fe29c7e8430790885d 100644
--- a/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h
+++ b/source/RobotAPI/libraries/armem/server/wm/memory_definitions.h
@@ -61,6 +61,7 @@ namespace armarx::armem::server::wm
         public base::ProviderSegmentBase<Entity, ProviderSegment>
         , public detail::MaxHistorySizeParent<ProviderSegment>
         , public armem::wm::detail::FindInstanceDataMixin<ProviderSegment>
+        , public armem::server::wm::detail::Prediction<ProviderSegment>
     {
     public:
@@ -83,9 +84,10 @@
     /// @brief base::CoreSegmentBase
     class CoreSegment :
-        public base::CoreSegmentBase<ProviderSegment, CoreSegment>,
-        public detail::MaxHistorySizeParent<CoreSegment>
+        public base::CoreSegmentBase<ProviderSegment, CoreSegment>
+        , public detail::MaxHistorySizeParent<CoreSegment>
         , public armem::wm::detail::FindInstanceDataMixin<CoreSegment>
+        , public armem::server::wm::detail::PredictionContainer<CoreSegment>
     {
         using Base = base::CoreSegmentBase<ProviderSegment, CoreSegment>;
@@ -126,6 +128,7 @@ namespace armarx::armem::server::wm
     class Memory :
         public base::MemoryBase<CoreSegment, Memory>
         , public armem::wm::detail::FindInstanceDataMixin<Memory>
+        , public armem::server::wm::detail::PredictionContainer<Memory>
     {
         using Base = base::MemoryBase<CoreSegment, Memory>;
@@ -147,11 +150,6 @@ namespace armarx::armem::server::wm
          */
         Base::UpdateResult updateLocking(const EntityUpdate& update);
-
-        std::vector<PredictionResult>
-        predict(const std::vector<PredictionRequest>& requests);
-
-
     };

 }
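
[Illustration — editor's annotation, not part of the patch] With Memory now deriving from PredictionContainer<Memory>, the dispatch path added above can be driven directly on the server-side working memory, which is what MemoryToIceAdapter::predict does. This is a hedged sketch only: it assumes a fully set-up server::wm::Memory whose segments report their engines via getAllPredictionEngines() and have matching predictors registered (see addPredictor); the engine name "SomeEngine" is a placeholder.

    #include <vector>

    #include <ArmarXCore/core/logging/Logging.h>

    #include <RobotAPI/libraries/armem/core/Prediction.h>
    #include <RobotAPI/libraries/armem/server/wm/memory_definitions.h>

    namespace example
    {
        using namespace armarx::armem;

        std::vector<PredictionResult>
        requestPrediction(server::wm::Memory& memory, const MemoryID& snapshotID)
        {
            PredictionRequest request;
            request.snapshotID = snapshotID;
            request.predictionSettings.predictionEngineID = "SomeEngine";  // placeholder

            // The container resolves the most specific segment offering the engine and
            // forwards the request down the hierarchy, or reports an error in the result.
            std::vector<PredictionResult> results = memory.dispatchPredictions({request});
            for (const PredictionResult& result : results)
            {
                if (not result.success)
                {
                    ARMARX_WARNING << result.errorMessage;
                }
            }
            return results;
        }
    } // namespace example
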