Skip to content
Snippets Groups Projects
Commit f4521d31 authored by Joana Plewnia
Browse files

added periodic tree update and made search faster

parent dd5631a8
No related branches found
No related tags found
1 merge request: !9 "Draft: tested version of Non-AI Verbalization Skills on Robot"
......@@ -42,9 +42,7 @@ class EM_Verbalization_Tree:
self.tree = HigherLevelSummary('', children=[tree])
self.ltm_base_path = self.ltm_base_path
with open(self.history_cache, 'wb') as file:
pickle.dump(self.tree)
self.history_cache.write_bytes(pickle.dumps(self.tree))
print("made tree")
def update_tree(self):
......@@ -52,8 +50,8 @@ class EM_Verbalization_Tree:
llm = instantiate_llm(llm_option_dict)
tree = extend_existing_history_from_memory_snapshots(self.tree, Path(self.ltm_base_path), llm, LLMBasedSummarizer(llm, example_db_name='armarx_lt_mem'))
self.tree = tree
with open(self.history_cache, 'wb') as file:
pickle.dump(self.tree)
self.history_cache.write_bytes(pickle.dumps(self.tree))
print("updated tree")
def answer_question_with_tree(self, question: str, config_file_name: str= "armarx_lt_mem/full", llm_option_dict:str= "{'type': 'ChatOpenAI', 'model_name': 'gpt-4o-mini', 'request_timeout': 30, 'max_retries': 2}"):
print("Triggered question answering...")
......
No preview for this file type
......@@ -42,9 +42,9 @@ def _iter_all_instances(
for s in e.snapshots.values():
for i in s.instances:
try:
i.load() # TODO: load only metadata
if i.metadata and i.metadata.timeReferenced > start_from_timestamp_microseconds:
# TODO: and then load data here
i.loadReferences()
if i.metadata.timeReferenced > start_from_timestamp_microseconds:
i.load()
yield i
except:
print("Could not load instance")
......@@ -193,7 +193,12 @@ def load_episode_from_armarx_lt_mem(
scenes = []
running_goal = '' # This is only there to handle actions with no executorName provided
for skill_evt_inst in skill_events:
skill_evt = skill_evt_inst.data.to_primitive()
try:
skill_evt = skill_evt_inst.data.to_primitive()
except:
print('Skill evt', skill_evt_inst.metadata, 'appears to be invalid, skipping')
traceback.print_exc()
continue
ts = datetime.fromtimestamp(skill_evt_inst.metadata.timeReferenced / 1e6)
action = skill_evt['skillId']['skillName']
if action == 'ResetGazeTargets':
......
......@@ -52,6 +52,15 @@ class EMVerbalizationAPI:
else:
self._history = make_tree_interactive(history, search_embedding_fn, search_filter_kwargs).all_leaves
try:
print('Initializing search embeddings eagerly...')
self._history.search('')
except SemanticHintError:
pass
finally:
self._history.collapse_deep()
#########################
# dialog
......
import datetime
from functools import partial, cache
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Optional, List, Tuple
import torch
from langchain_core.language_models import BaseChatModel
from sentence_transformers import SentenceTransformer
......@@ -105,12 +107,46 @@ def create_search_embedding_and_cfg(search_cfg: Optional[dict]):
embedding_model_name = search_cfg.pop('embedding', 'all-MiniLM-L6-v2')
embedding_model = SentenceTransformer(embedding_model_name)
cache = {}
cache_file = Path('search-embedding-cache.pt')
if cache_file.is_file():
cache = torch.load(cache_file, map_location=embedding_model.device)
write_cache_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix='search-emb-cache-writer')
def _embed_cached(texts: Tuple[str, ...]):
    """Return one embedding row per entry of ``texts``, reusing cached rows.

    Texts already present in the enclosing ``cache`` dict (text -> embedding)
    are copied from it; the remaining texts are encoded in a single
    ``embedding_model.encode`` call, stored in ``cache``, and the cache is
    then persisted to ``cache_file`` on the background writer executor.

    The former ``functools.cache`` decorator is intentionally gone: the name
    ``cache`` now refers to the dict above, and memoisation happens per text
    rather than per argument tuple.

    :param texts: the texts to embed; row ``i`` of the result belongs to
        ``texts[i]``.
    :return: tensor with ``len(texts)`` rows on ``embedding_model.device``.
    """
    result = torch.empty(len(texts), embedding_model.get_sentence_embedding_dimension(),
                         device=embedding_model.device)
    todo_texts, todo_indices = [], []
    for i, text in enumerate(texts):
        if text in cache:
            result[i] = cache[text]
        else:
            todo_texts.append(text)
            todo_indices.append(i)
    print('Embedding', len(texts), ', new:', len(todo_texts))
    if todo_indices:
        new_embeddings = embedding_model.encode(list(todo_texts), convert_to_tensor=True)
        # Rows of new_embeddings are in the same order as todo_texts.
        result[todo_indices] = new_embeddings
        for text, emb in zip(todo_texts, new_embeddings):
            cache[text] = emb
        # Copy the cache here, on the calling thread, so the background save
        # writes exactly the state after this call instead of whatever the
        # dict contains when the worker thread eventually runs.
        snapshot = dict(cache)
        write_cache_executor.submit(torch.save, snapshot, cache_file)
    return result
def _embed(texts: List[str]):
    """Embed ``texts``, encoding every distinct string only once.

    Duplicates are collapsed before calling ``_embed_cached`` and the result
    is expanded again, so row ``i`` of the returned tensor is the embedding
    of ``texts[i]``.

    :param texts: texts to embed; may contain duplicates and may be empty.
    :return: tensor with ``len(texts)`` rows, on the same device as the
        embeddings returned by ``_embed_cached``.
    """
    # Map each distinct text to its position among the unique texts. Dicts
    # keep insertion order, so the keys are the unique texts in first-seen
    # order and each lookup is O(1) instead of a list scan.
    unique_positions = {}
    original_to_unique_indices = [
        unique_positions.setdefault(text, len(unique_positions)) for text in texts
    ]
    embeddings = _embed_cached(tuple(unique_positions))
    # index_select needs an integer index on the same device as its input;
    # the explicit dtype also keeps the empty-input case valid (an empty
    # list would otherwise produce a float tensor).
    index = torch.tensor(original_to_unique_indices, dtype=torch.long, device=embeddings.device)
    return torch.index_select(embeddings, 0, index)
return _embed, search_cfg.pop('filter_kwargs', {})
......@@ -6,6 +6,7 @@ import pickle
from pathlib import Path
from armarx_core import ice_manager, slice_loader
from episodic_verbalization import EM_Verbalization_Tree
from tree_updater import main as updateTreePeriodically
slice_loader.load_armarx_slice(
"armarx_speech", "../../armarx/speech/skills/EpisodicVerbalization/core/EpisodicVerbalizationPythonInterface.ice"
......@@ -78,6 +79,9 @@ def main():
logger.info("Connection was refused, waiting for Ice connection...")
sleep(1) # try to connect every second
ice_manager.register_object(component, ice_object_name=name)
updateTreePeriodically()
ice_manager.wait_for_shutdown()
if __name__ == '__main__':
......
......@@ -66,7 +66,7 @@ def read_memory_ids(mns, core_segment_id: mem.MemoryID):
else:
print(data)
print("Added ", len(result_data), " instance ids")
#print("Added ", len(result_data), " instance ids")
return result_data
......@@ -81,7 +81,7 @@ def get_last_skill_id(skill_event_ids):
def periodicUpdateSkillTree():
def periodicUpdateSkillTree(mns, core_segment_id):
tree = EM_Verbalization_Tree()
last_id = 0
while True:
......@@ -94,10 +94,12 @@ def periodicUpdateSkillTree():
last_id = last_skill_id
print("updating tree...")
tree.update_tree()
sleep(60 * 3) # update every 3 minutes (average time needed to update tree if new skill was available)
else:
print("No new skill events detected")
sleep(15) # no new skill event: poll the memory again after 15 seconds
if __name__ == '__main__':
def main():
print("Running tree updater ...")
# Get the Memory Name System.
mns = memcl.MemoryNameSystem.wait_for_mns()
......@@ -105,4 +107,7 @@ if __name__ == '__main__':
memory_id = mem.MemoryID("Skill")
core_segment_id = memory_id.with_core_segment_name("SkillEvent")
periodicUpdateSkillTree()
periodicUpdateSkillTree(mns, core_segment_id)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment