diff --git a/python/skills/skills/episodic_verbalization.py b/python/skills/skills/episodic_verbalization.py index 6ce064f8145fe3363e07081f6fa5056bb43b4c6f..d57afc95bc7cbf23bc2d7051baebb0940016e486 100644 --- a/python/skills/skills/episodic_verbalization.py +++ b/python/skills/skills/episodic_verbalization.py @@ -42,9 +42,7 @@ class EM_Verbalization_Tree: self.tree = HigherLevelSummary('', children=[tree]) self.ltm_base_path = self.ltm_base_path - with open(self.history_cache, 'wb') as file: - pickle.dump(self.tree) - + self.history_cache.write_bytes(pickle.dumps(self.tree)) print("made tree") def update_tree(self): @@ -52,8 +50,8 @@ class EM_Verbalization_Tree: llm = instantiate_llm(llm_option_dict) tree = extend_existing_history_from_memory_snapshots(self.tree, Path(self.ltm_base_path), llm, LLMBasedSummarizer(llm, example_db_name='armarx_lt_mem')) self.tree = tree - with open(self.history_cache, 'wb') as file: - pickle.dump(self.tree) + self.history_cache.write_bytes(pickle.dumps(self.tree)) + print("updated tree") def answer_question_with_tree(self, question: str, config_file_name: str= "armarx_lt_mem/full", llm_option_dict:str= "{'type': 'ChatOpenAI', 'model_name': 'gpt-4o-mini', 'request_timeout': 30, 'max_retries': 2}"): print("Triggered question answering...") diff --git a/python/skills/skills/episodic_verbalization_core/data/armarx_lt_mem/2024-09-18-a7a-predef.pkl b/python/skills/skills/episodic_verbalization_core/data/armarx_lt_mem/2024-09-18-a7a-predef.pkl index b910ec50cdb750ba9019c05a6e124d7633b9da7e..70695f850dd9a58e129a638f376dd117bc8de4ba 100644 Binary files a/python/skills/skills/episodic_verbalization_core/data/armarx_lt_mem/2024-09-18-a7a-predef.pkl and b/python/skills/skills/episodic_verbalization_core/data/armarx_lt_mem/2024-09-18-a7a-predef.pkl differ diff --git a/python/skills/skills/episodic_verbalization_core/em/armarx_lt_mem.py b/python/skills/skills/episodic_verbalization_core/em/armarx_lt_mem.py index 
73acc093682a61d4a2b506d28857aaedbd2ec057..b3a14749e4ee1df0eeff7dca5c945dc8168deddb 100644 --- a/python/skills/skills/episodic_verbalization_core/em/armarx_lt_mem.py +++ b/python/skills/skills/episodic_verbalization_core/em/armarx_lt_mem.py @@ -42,9 +42,9 @@ def _iter_all_instances( for s in e.snapshots.values(): for i in s.instances: try: - i.load() # TODO: load only metadata - if i.metadata and i.metadata.timeReferenced > start_from_timestamp_microseconds: - # TODO: and then load data here + i.loadReferences() + if i.metadata.timeReferenced > start_from_timestamp_microseconds: + i.load() yield i except: print("Could not load instance") @@ -193,7 +193,12 @@ def load_episode_from_armarx_lt_mem( scenes = [] running_goal = '' # This is only there to handle actions with no executorName provided for skill_evt_inst in skill_events: - skill_evt = skill_evt_inst.data.to_primitive() + try: + skill_evt = skill_evt_inst.data.to_primitive() + except: + print('Skill evt', skill_evt_inst.metadata, 'appears to be invalid, skipping') + traceback.print_exc() + continue ts = datetime.fromtimestamp(skill_evt_inst.metadata.timeReferenced / 1e6) action = skill_evt['skillId']['skillName'] if action == 'ResetGazeTargets': diff --git a/python/skills/skills/episodic_verbalization_core/llm_emv/emv_api.py b/python/skills/skills/episodic_verbalization_core/llm_emv/emv_api.py index 8feb66945f3c9b65a6002b3137c451a848f72acc..aff671ed977d78216b6e8877e5af79b518508284 100644 --- a/python/skills/skills/episodic_verbalization_core/llm_emv/emv_api.py +++ b/python/skills/skills/episodic_verbalization_core/llm_emv/emv_api.py @@ -52,6 +52,15 @@ class EMVerbalizationAPI: else: self._history = make_tree_interactive(history, search_embedding_fn, search_filter_kwargs).all_leaves + + try: + print('Initializing search embeddings eagerly...') + self._history.search('') + except SemanticHintError: + pass + finally: + self._history.collapse_deep() + ######################### # dialog diff --git 
a/python/skills/skills/episodic_verbalization_core/llm_emv/setup.py b/python/skills/skills/episodic_verbalization_core/llm_emv/setup.py index 16adc9d02b6ab074155919944ac7fda65696e404..a665d650d5e5e99bf6f11c7d34586d33f64c88ca 100644 --- a/python/skills/skills/episodic_verbalization_core/llm_emv/setup.py +++ b/python/skills/skills/episodic_verbalization_core/llm_emv/setup.py @@ -1,8 +1,10 @@ import datetime -from functools import partial, cache +from concurrent.futures import ThreadPoolExecutor +from functools import partial from pathlib import Path from typing import Optional, List, Tuple +import torch from langchain_core.language_models import BaseChatModel from sentence_transformers import SentenceTransformer @@ -105,12 +107,46 @@ def create_search_embedding_and_cfg(search_cfg: Optional[dict]): embedding_model_name = search_cfg.pop('embedding', 'all-MiniLM-L6-v2') embedding_model = SentenceTransformer(embedding_model_name) + cache = {} + cache_file = Path('search-embedding-cache.pt') + if cache_file.is_file(): + cache = torch.load(cache_file, map_location=embedding_model.device) + write_cache_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix='search-emb-cache-writer') + - @cache # cache requires hashable argument, therefore using tuple here def _embed_cached(texts: Tuple[str, ...]): - return embedding_model.encode(list(texts), convert_to_tensor=True) + result = torch.empty(len(texts), embedding_model.get_sentence_embedding_dimension(), + device=embedding_model.device) + todo_texts, todo_indices = [], [] + for i, text in enumerate(texts): + if text in cache: + result[i] = cache[text] + else: + todo_texts.append(text) + todo_indices.append(i) + + print('Embedding', len(texts), ', new:', len(todo_texts)) + if todo_indices: + new_embeddings = embedding_model.encode(list(todo_texts), convert_to_tensor=True) + result[todo_indices] = new_embeddings + for text, emb in zip(todo_texts, new_embeddings): + cache[text] = emb + write_cache_executor.submit(lambda: 
torch.save(dict(cache), cache_file)) + return result + def _embed(texts: List[str]): - return _embed_cached(tuple(texts)) + original_to_unique_indices = [] + unique_entries: List[str] = [] + for text in texts: + try: + idx = unique_entries.index(text) + original_to_unique_indices.append(idx) + except ValueError: + original_to_unique_indices.append(len(unique_entries)) + unique_entries.append(text) + embeddings = _embed_cached(tuple(unique_entries)) + return torch.index_select(embeddings, 0, torch.tensor(original_to_unique_indices)) + return _embed, search_cfg.pop('filter_kwargs', {}) diff --git a/python/skills/skills/ice_service.py b/python/skills/skills/ice_service.py index a26221b225332fbbd58609cb03110b802ab83450..18799ac2a0665659604cc510a163585bc364e2d6 100644 --- a/python/skills/skills/ice_service.py +++ b/python/skills/skills/ice_service.py @@ -6,6 +6,7 @@ import pickle from pathlib import Path from armarx_core import ice_manager, slice_loader from episodic_verbalization import EM_Verbalization_Tree +from tree_updater import main as updateTreePeriodically slice_loader.load_armarx_slice( "armarx_speech", "../../armarx/speech/skills/EpisodicVerbalization/core/EpisodicVerbalizationPythonInterface.ice" @@ -78,6 +79,9 @@ def main(): logger.info("Connection was refused, waiting for Ice connection...") sleep(1) # try to connect every second ice_manager.register_object(component, ice_object_name=name) + + updateTreePeriodically() + ice_manager.wait_for_shutdown() if __name__ == '__main__': diff --git a/python/skills/skills/tree_updater.py b/python/skills/skills/tree_updater.py index ebc4d2ff4c6fb6c20cdfc9fd97bd8864ab7fde34..c4fd49967bbded09e2e7f98fb393835b17b8d096 100644 --- a/python/skills/skills/tree_updater.py +++ b/python/skills/skills/tree_updater.py @@ -66,7 +66,7 @@ def read_memory_ids(mns, core_segment_id: mem.MemoryID): else: print(data) - print("Added ", len(result_data), " instance ids") + #print("Added ", len(result_data), " instance ids") return 
result_data @@ -81,7 +81,7 @@ def get_last_skill_id(skill_event_ids): -def periodicUpdateSkillTree(): +def periodicUpdateSkillTree(mns, core_segment_id): tree = EM_Verbalization_Tree() last_id = 0 while True: @@ -94,10 +94,12 @@ def periodicUpdateSkillTree(): last_id = last_skill_id print("updating tree...") tree.update_tree() - sleep(60 * 3) # update every 3 minutes (average time needed to update tree if new skill was available) + else: + print("No new skill events detected") + sleep(15) # poll every 15 seconds for new skill events; the tree update itself is the slow step + -if __name__ == '__main__': +def main(): print("Running tree updater ...") # Get the Memory Name System. mns = memcl.MemoryNameSystem.wait_for_mns() @@ -105,4 +107,7 @@ if __name__ == '__main__': memory_id = mem.MemoryID("Skill") core_segment_id = memory_id.with_core_segment_name("SkillEvent") - periodicUpdateSkillTree() + periodicUpdateSkillTree(mns, core_segment_id) + +if __name__ == '__main__': + main()