From e661de3567ad32e3aa691b6eb44785d6aedab7e4 Mon Sep 17 00:00:00 2001
From: Rainer Kartmann <rainer.kartmann@kit.edu>
Date: Fri, 3 Jun 2022 18:31:43 +0200
Subject: [PATCH] Add generation of C++ code from sympy expression

---
 .../hemisphere_joint_demo/sympy_to_code.py    | 278 ++++++++++++++++++
 python/hemisphere-joint-demo/pyproject.toml   |   6 +-
 2 files changed, 283 insertions(+), 1 deletion(-)
 create mode 100644 python/hemisphere-joint-demo/hemisphere_joint_demo/sympy_to_code.py

diff --git a/python/hemisphere-joint-demo/hemisphere_joint_demo/sympy_to_code.py b/python/hemisphere-joint-demo/hemisphere_joint_demo/sympy_to_code.py
new file mode 100644
index 000000000..59013a102
--- /dev/null
+++ b/python/hemisphere-joint-demo/hemisphere_joint_demo/sympy_to_code.py
@@ -0,0 +1,278 @@
+import dataclasses as dc
+import os.path
+
+import sympy as sp
+import typing as ty
+
+from collections import OrderedDict
+
+
+@dc.dataclass
+class Line:
+    lhs: str
+    rhs: str
+
+    def make_decl(self) -> str:
+        return f"double {self.lhs} = 0;"
+
+    def make_impl(self):
+        return f"{self.lhs} = {self.rhs};"
+
+    @classmethod
+    def lhs_from_expr(cls, expr: sp.Basic) -> str:
+        lhs = ("_" + str(expr)
+               .replace(" ", "")
+               .replace("**", "_pow_")
+               .replace("+", "_add_")
+               .replace("-", "_add_")
+               .replace("*", "_mul_")
+               .replace("/", "_div_")
+               .replace("(", "__lpar_")
+               .replace(")", "_rpar__")
+               )
+        return lhs
+
+    @classmethod
+    def rhs_from_expr(cls, expr: sp.Basic, ctx: "Context") -> str:
+        # Recurse.
+        code_args = [expr_to_cpp(arg, ctx) for arg in expr.args]
+
+        def par(code: str) -> str:
+            return f"({code})"
+
+        def op(op: str) -> str:
+            return par(f" {op} ".join(code_args))
+
+        def fn(func: str) -> str:
+            return f"{func}({', '.join(code_args)})"
+
+        if isinstance(expr, sp.Add):
+            return op("+")
+
+        elif isinstance(expr, sp.Subs):
+            return op("-")
+
+        elif isinstance(expr, sp.Mul):
+            # Special case: a/b = a * (b^-1)
+            return op("*")
+
+        elif isinstance(expr, sp.Pow):
+            assert len(code_args) == 2
+            base, exponent = code_args
+            if exponent == "-1":
+                return par(f"1 / {base}")
+            elif exponent == "2":
+                return par(f"{base} * {base}")
+            else:
+                return fn("std::pow")
+
+        elif isinstance(expr, sp.sin):
+            return fn("std::sin")
+
+        elif isinstance(expr, sp.cos):
+            return fn("std::cos")
+
+        else:
+            raise TypeError(f"{expr.__class__}, {expr.func}({expr.args}) = {sp.srepr(expr)}")
+
+
+    @classmethod
+    def from_expr(cls, expr: sp.Basic, ctx: "Context"):
+        lhs = Line.lhs_from_expr(expr)
+        rhs = Line.rhs_from_expr(expr, ctx=ctx)
+        return cls(lhs=lhs, rhs=rhs)
+
+
+@dc.dataclass
+class Context:
+
+    name = "Expressions"
+
+    function_args: ty.List[sp.Symbol] = dc.field(default_factory=list)
+    named_expressions: ty.OrderedDict[sp.Basic, Line] = dc.field(default_factory=OrderedDict)
+    function_results: ty.Dict[str, sp.Basic] = dc.field(default_factory=list)
+
+    depth = 0
+    indent = " " * 4
+
+    def make_compute_args(self):
+        return ", ".join(f"double {arg}" for arg in self.function_args)
+
+    def make_compute_signature_decl(self):
+        return f"void compute({self.make_compute_args()});"
+
+    def make_compute_signature_impl(self):
+        return f"void {self.name}::compute({self.make_compute_args()})"
+
+    def make_decl_lines(self) -> ty.List[str]:
+        lines = self._line_sum(
+            [
+                f"class {self.name}",
+                "{",
+                "public:",
+                self.indent + "",
+                self.indent + self.make_compute_signature_decl(),
+                self.indent + "",
+                self.indent + "// Input arguments:"
+            ],
+            [self.indent + f"double {arg} = 0;" for arg in self.function_args],
+            [
+                self.indent + "",
+                self.indent + "// Results:"
+            ],
+            [self.indent + f"double {res} = 0;" for res in self.function_results],
+            [
+                self.indent + "",
+                self.indent + "// Intermediate expressions:"
+            ],
+            [self.indent + line.make_decl() for expr, line in self.named_expressions.items()],
+            [
+                self.indent + "",
+                "};",
+            ]
+        )
+        return lines
+
+
+    def make_impl_lines(self) -> ty.List[str]:
+        lines = self._line_sum(
+            [
+                self.make_compute_signature_impl(),
+                "{"
+            ],
+            [self.indent + f"this->{arg} = {arg};" for arg in self.function_args],
+            [
+                self.indent + "",
+            ],
+            [self.indent + line.make_impl() for expr, line in self.named_expressions.items()],
+            [
+                self.indent + "",
+            ],
+            [self.indent + Line(lhs=res, rhs=Line.lhs_from_expr(expr)).make_impl()
+             for res, expr in self.function_results.items()],
+            [
+                "}",
+            ],
+        )
+        return lines
+
+    def make_header_lines(self):
+        lines = self._line_sum(
+            ["#pragma once"],
+            [""] * 2,
+            self.make_decl_lines(),
+            [""] * 1,
+        )
+        return lines
+
+    def make_source_lines(self):
+        lines = self._line_sum(
+            [f'#include "{self.name}.h"'],
+            [""] * 1,
+            ["#include <cmath>"],
+            [""] * 2,
+            self.make_impl_lines(),
+            [""] * 1,
+        )
+        return lines
+
+    @classmethod
+    def format_lines(cls, lines: ty.List[str], line_numbers=False) -> str:
+        if line_numbers:
+            lines = [f"{i:>3} | {line}" for i, line in enumerate(lines)]
+        return "\n".join(lines)
+
+    @classmethod
+    def write_lines(cls, lines: ty.List[str], filepath: str):
+        with open(filepath, "w") as file:
+            file.writelines([l.rstrip() + "\n" for l in lines])
+
+    def _line_sum(self, *args):
+        return sum(args, [])
+
+    def build(self):
+        for name, expr in self.function_results.items():
+            expr_to_cpp(expr, self)
+
+
+def expr_to_cpp(
+        expr: sp.Basic,
+        ctx: Context,
+) -> str:
+    indent = "  " * ctx.depth
+
+    if len(expr.args) == 0:
+        # Leaf.
+        print(f"{indent}Leaf: {expr}")
+
+        if isinstance(expr, sp.Symbol):
+            # Must be part of local variables.
+            assert expr in ctx.function_args
+
+        elif isinstance(expr, sp.Number):
+            # Will be turned into a literal.
+            pass
+
+        else:
+            raise TypeError(f"Got expr {expr} of type {type(expr)}")
+
+        return str(expr)
+
+    else:
+        # Non-leaf
+        print(f"{indent}Node: {expr}")
+        ctx.depth += 1
+
+        line = Line.from_expr(expr, ctx=ctx)
+        ctx.named_expressions[expr] = line
+
+        ctx.depth -= 1
+
+        return line.lhs
+
+
+
+if __name__ == '__main__':
+    from sympy import sin, cos, sqrt
+
+    # Actuation
+    a1, a2 = sp.symbols('a1 a2')
+    # Constants defining deometry
+    lever, theta0 = sp.symbols('lever theta0')
+    # P1_z=f(motor1)
+    # P1_z=f(motor2)
+
+    # for 90°
+    ex = 2 * lever * (lever ** 5 * sin(theta0) - lever ** 3 * a1 * a2 * sin(theta0) - lever ** 3 * a2 ** 2 * sin(theta0) + lever * a1 * a2 ** 3 * sin(theta0) - a2 * sqrt(-2 * lever ** 8 * sin(theta0) ** 2 + lever ** 8 + 2 * lever ** 6 * a1 ** 2 * sin(theta0) ** 2 - lever ** 6 * a1 ** 2 + 2 * lever ** 6 * a1 * a2 * sin(theta0) ** 2 + 2 * lever ** 6 * a2 ** 2 * sin(theta0) ** 2 - lever ** 6 * a2 ** 2 - 2 * lever ** 4 * a1 ** 3 * a2 * sin(theta0) ** 2 - 2 * lever ** 4 * a1 ** 2 * a2 ** 2 * sin(theta0) ** 2 - 2 * lever ** 4 * a1 * a2 ** 3 * sin(theta0) ** 2 + lever ** 2 * a1 ** 4 * a2 ** 2 + 2 * lever ** 2 * a1 ** 3 * a2 ** 3 * sin(theta0) ** 2 + lever ** 2 * a1 ** 2 * a2 ** 4 - a1 ** 4 * a2 ** 4)) * sin(theta0) / (sqrt(lever ** 2 - a2 ** 2) * (lever ** 4 - a1 ** 2 * a2 ** 2))
+    ey = 2 * lever * (lever ** 5 * sin(theta0) - lever ** 3 * a1 ** 2 * sin(theta0) - lever ** 3 * a1 * a2 * sin(theta0) + lever * a1 ** 3 * a2 * sin(theta0) - a1 * sqrt(-2 * lever ** 8 * sin(theta0) ** 2 + lever ** 8 + 2 * lever ** 6 * a1 ** 2 * sin(theta0) ** 2 - lever ** 6 * a1 ** 2 + 2 * lever ** 6 * a1 * a2 * sin(theta0) ** 2 + 2 * lever ** 6 * a2 ** 2 * sin(theta0) ** 2 - lever ** 6 * a2 ** 2 - 2 * lever ** 4 * a1 ** 3 * a2 * sin(theta0) ** 2 - 2 * lever ** 4 * a1 ** 2 * a2 ** 2 * sin(theta0) ** 2 - 2 * lever ** 4 * a1 * a2 ** 3 * sin(theta0) ** 2 + lever ** 2 * a1 ** 4 * a2 ** 2 + 2 * lever ** 2 * a1 ** 3 * a2 ** 3 * sin(theta0) ** 2 + lever ** 2 * a1 ** 2 * a2 ** 4 - a1 ** 4 * a2 ** 4)) * sin(theta0) / (sqrt(lever ** 2 - a1 ** 2) * (lever ** 4 - a1 ** 2 * a2 ** 2))
+    ez = 2 * lever * (lever * (lever ** 4 - a1 ** 2 * a2 ** 2) * (a1 + a2) * sin(theta0) + (lever ** 2 + a1 * a2) * sqrt(lever ** 2 * (lever ** 2 * a1 + lever ** 2 * a2 - a1 ** 2 * a2 - a1 * a2 ** 2) ** 2 * sin(theta0) ** 2 + (-lever ** 4 + a1 ** 2 * a2 ** 2) * (-2 * lever ** 4 * cos(theta0) ** 2 + lever ** 4 + lever ** 2 * a1 ** 2 * cos(theta0) ** 2 + lever ** 2 * a2 ** 2 * cos(theta0) ** 2 - a1 ** 2 * a2 ** 2))) * sin(theta0) / ((lever ** 2 + a1 * a2) * (lever ** 4 - a1 ** 2 * a2 ** 2))
+
+    ctx = Context(
+        function_args=[a1, a2, lever, theta0],
+        function_results={
+            "ex": ex,
+            "ey": ey,
+            "ez": ez,
+        }
+    )
+    ctx.build()
+
+    output_dir = "/home/rkartmann/code/simox/VirtualRobot/examples/HemisphereJoint/"
+    header_path = os.path.join(output_dir, ctx.name + ".h")
+    source_path = os.path.join(output_dir, ctx.name + ".cpp")
+
+    header_lines = ctx.make_header_lines()
+    source_lines = ctx.make_source_lines()
+
+    print("Declaration:")
+    print(ctx.format_lines(header_lines, line_numbers=True))
+    print("Implementation:")
+    print(ctx.format_lines(source_lines, line_numbers=True))
+
+    print("Writing files...")
+    print(f"- {header_path}")
+    print(ctx.write_lines(header_lines, header_path))
+    print(f"- {source_path}")
+    print(ctx.write_lines(source_lines, source_path))
+
+    print("Done.")
diff --git a/python/hemisphere-joint-demo/pyproject.toml b/python/hemisphere-joint-demo/pyproject.toml
index 7c9d0cb1f..4a155e8e7 100644
--- a/python/hemisphere-joint-demo/pyproject.toml
+++ b/python/hemisphere-joint-demo/pyproject.toml
@@ -6,10 +6,14 @@ authors = ["Rainer Kartmann <rainer dot kartmann at kit dot edu>"]
 
 [tool.poetry.dependencies]
 python = "^3.6"
-armarx-dev = { path = "../../../armarx/python3-armarx/", develop=true} # "^0.16.4"
+# armarx-dev = { path = "../../../armarx/python3-armarx/", develop=true} # "^0.16.4"
 zeroc-ice = "3.7.0"
 numpy = "^1.19.5"
 jupyter = "^1.0.0"
+plotly = "*"
+matplotlib = "*"
+scipy = "*"
+sympy = "*"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"
-- 
GitLab