[Bf-blender-cvs] [51885aea9d1] functions: simple inferencer in python with tests

Jacques Lucke noreply at git.blender.org
Tue Mar 12 18:29:44 CET 2019


Commit: 51885aea9d1f12fbe1ec7b6569262523b2f95bbf
Author: Jacques Lucke
Date:   Tue Mar 12 18:29:16 2019 +0100
Branches: functions
https://developer.blender.org/rB51885aea9d1f12fbe1ec7b6569262523b2f95bbf

simple inferencer in python with tests

===================================================================

A	release/scripts/startup/function_nodes/inferencer.py
A	release/scripts/startup/function_nodes/test_inferencer.py
A	release/scripts/startup/function_nodes/tests.py

===================================================================

diff --git a/release/scripts/startup/function_nodes/inferencer.py b/release/scripts/startup/function_nodes/inferencer.py
new file mode 100644
index 00000000000..68bdf9768a0
--- /dev/null
+++ b/release/scripts/startup/function_nodes/inferencer.py
@@ -0,0 +1,88 @@
class Inferencer:
    """Infer the data type of ids from known types and constraints between ids.

    Record known types with ``insert_final_type`` and relations with the
    ``insert_*_constraint`` methods, run ``inference()``, then read results
    with ``get_final_type``.
    """

    def __init__(self, type_infos):
        # Handed to every ListConstraint, which calls its ``to_base`` and
        # ``to_list`` methods to convert between list and base types.
        self.type_infos = type_infos
        # id -> data type, for every id whose type is known so far.
        self.finalized_ids = dict()
        # Constraints not resolved yet; ``inference()`` empties this set.
        self.constraints = set()

    def insert_final_type(self, id, data_type):
        """Record that ``id`` has type ``data_type``.

        Unlike ``finalize_id`` this performs no consistency check: a type
        previously recorded for ``id`` is silently replaced.
        """
        self.finalized_ids[id] = data_type

    def insert_equality_constraint(self, ids):
        """Require all ``ids`` to end up with the same type."""
        self.constraints.add(EqualityConstraint(ids))

    def insert_list_constraint(self, list_ids, base_ids=()):
        """Require ``list_ids`` to share a list type and ``base_ids`` its base type."""
        self.constraints.add(ListConstraint(list_ids, base_ids, self.type_infos))

    def finalize_id(self, id, data_type):
        """Assign ``data_type`` to ``id``, or check it against the known type.

        Raises:
            CannotInferenceError: ``id`` already has a different type.
        """
        if id not in self.finalized_ids:
            self.finalized_ids[id] = data_type
        elif self.finalized_ids[id] != data_type:
            # Name the id and both types so a failing node tree can be debugged.
            raise CannotInferenceError(
                "conflicting types for id {!r}: {!r} and {!r}".format(
                    id, self.finalized_ids[id], data_type))

    def finalize_ids(self, ids, data_type):
        """Call ``finalize_id`` for every id in ``ids``; stops at the first conflict."""
        for single_id in ids:
            self.finalize_id(single_id, data_type)

    def inference(self):
        """Resolve all pending constraints.

        Each pass lets every pending constraint try to propagate types from
        already finalized ids; constraints that succeed are dropped. Passes
        repeat until no constraint is left.

        Raises:
            CannotInferenceError: a full pass resolved no constraint, or a
                constraint propagated a type that conflicts with a known one.
        """
        while self.constraints:
            # Later constraints in the same pass already see types finalized
            # by earlier ones.
            handled_constraints = {
                constraint for constraint in self.constraints
                if constraint.try_finalize(self.finalized_ids, self.finalize_ids)}

            if not handled_constraints:
                # No progress is possible: the remaining constraints would
                # be retried forever with the same known types.
                raise CannotInferenceError(
                    "cannot resolve {} remaining constraint(s) from the "
                    "currently known types".format(len(self.constraints)))

            self.constraints -= handled_constraints

    def get_final_type(self, id):
        """Return the type known for ``id``; raises KeyError if it has none."""
        return self.finalized_ids[id]
+
+
class Constraint:
    """Interface for a relation between ids that the Inferencer resolves.

    Concrete constraints override ``try_finalize``.
    """

    def try_finalize(self, finalized_ids, do_finalize):
        """Try to propagate known types to the ids of this constraint.

        ``finalized_ids`` maps ids to their known types. ``do_finalize`` is
        called as ``do_finalize(ids, data_type)`` to assign one type to a
        group of ids. Return True once the constraint is resolved, or False
        if it cannot be resolved yet.
        """
        raise NotImplementedError
+
class EqualityConstraint(Constraint):
    """Constraint requiring every id of a group to have the same type."""

    def __init__(self, ids):
        # Kept as a set: duplicate ids collapse into one entry.
        self.ids = set(ids)

    def can_be_finalized(self, finalized_ids):
        """Return True if at least one id of the group already has a known type."""
        for candidate in self.ids:
            if candidate in finalized_ids:
                return True
        return False

    def try_finalize(self, finalized_ids, finalize_do):
        """Spread the type of the first already-typed id to the whole group.

        Return False, without calling ``finalize_do``, when no id of the
        group has a known type yet.
        """
        typed = [candidate for candidate in self.ids if candidate in finalized_ids]
        if not typed:
            return False
        finalize_do(self.ids, finalized_ids[typed[0]])
        return True
+
class ListConstraint(Constraint):
    """Constraint tying a group of list-typed ids to a group of base-typed ids.

    All ``list_ids`` share one list type and all ``base_ids`` share the
    matching base type. ``type_infos.to_base`` and ``type_infos.to_list``
    convert between the two.
    """

    def __init__(self, list_ids, base_ids, type_infos):
        self.list_ids = set(list_ids)
        self.base_ids = set(base_ids)
        self.type_infos = type_infos

    def try_finalize(self, finalized_ids, finalize_do):
        """Propagate a known type from either group to both groups.

        A typed id among ``list_ids`` takes precedence over one among
        ``base_ids``. Return False when neither group has a typed id.
        """
        resolved_from_list = self._propagate(
            finalized_ids, finalize_do, self.list_ids, self.base_ids,
            lambda list_type: self.type_infos.to_base(list_type))
        if resolved_from_list:
            return True
        return self._propagate(
            finalized_ids, finalize_do, self.base_ids, self.list_ids,
            lambda base_type: self.type_infos.to_list(base_type))

    @staticmethod
    def _propagate(finalized_ids, finalize_do, source_ids, target_ids, convert):
        # Finds the first typed id in source_ids and converts its type for
        # the other group. Then finalizes the source group before the target.
        for source_id in source_ids:
            if source_id in finalized_ids:
                source_type = finalized_ids[source_id]
                target_type = convert(source_type)
                finalize_do(source_ids, source_type)
                finalize_do(target_ids, target_type)
                return True
        return False
+
class CannotInferenceError(Exception):
    """Raised when the Inferencer cannot determine a consistent type for every id."""
diff --git a/release/scripts/startup/function_nodes/test_inferencer.py b/release/scripts/startup/function_nodes/test_inferencer.py
new file mode 100644
index 00000000000..94ca6eb4f93
--- /dev/null
+++ b/release/scripts/startup/function_nodes/test_inferencer.py
@@ -0,0 +1,47 @@
+import unittest
+from . inferencer import Inferencer, CannotInferenceError
+from . sockets import info
+
class TestInferencer(unittest.TestCase):
    """Unit tests for Inferencer, using the socket type info ``sockets.info``."""

    def setUp(self):
        self.inferencer = Inferencer(info)

    def test_single_equality(self):
        self.inferencer.insert_equality_constraint((1, 2))
        self.inferencer.insert_final_type(1, "Float")
        self.inferencer.inference()

        self.assertEqual(self.inferencer.get_final_type(1), "Float")
        self.assertEqual(self.inferencer.get_final_type(2), "Float")

    def test_multiple_equality(self):
        self.inferencer.insert_equality_constraint((1, 2, 3))
        self.inferencer.insert_equality_constraint((3, 4, 5))
        self.inferencer.insert_final_type(4, "Integer")
        self.inferencer.inference()

        # Every id of both chained groups must have received the type.
        for checked_id in (1, 2, 3, 4, 5):
            with self.subTest(id=checked_id):
                self.assertEqual(self.inferencer.get_final_type(checked_id), "Integer")

    def test_find_base(self):
        self.inferencer.insert_list_constraint((1,), (2,))
        self.inferencer.insert_final_type(1, "Float List")
        self.inferencer.inference()

        self.assertEqual(self.inferencer.get_final_type(2), "Float")

    def test_find_list(self):
        self.inferencer.insert_list_constraint((1,), (2,))
        self.inferencer.insert_final_type(2, "Vector")
        self.inferencer.inference()

        self.assertEqual(self.inferencer.get_final_type(1), "Vector List")

    def test_invalid_equality(self):
        self.inferencer.insert_equality_constraint((1, 2))
        self.inferencer.insert_final_type(1, "Float")
        self.inferencer.insert_final_type(2, "Integer")

        with self.assertRaises(CannotInferenceError):
            self.inferencer.inference()

    def test_invalid_transitive_equality(self):
        # 1 == 2 == 3 cannot hold when 1 and 3 have different types.
        self.inferencer.insert_equality_constraint((1, 2))
        self.inferencer.insert_equality_constraint((2, 3))
        self.inferencer.insert_final_type(1, "Float")
        self.inferencer.insert_final_type(3, "Integer")

        with self.assertRaises(CannotInferenceError):
            self.inferencer.inference()

    def test_unresolvable_equality(self):
        # No id has a known type, so the constraint can never be resolved.
        self.inferencer.insert_equality_constraint((1, 2))

        with self.assertRaises(CannotInferenceError):
            self.inferencer.inference()

    def test_invalid_list_base(self):
        # "Float List" implies the base type "Float" (see test_find_base),
        # which conflicts with the explicitly given "Integer".
        self.inferencer.insert_list_constraint((1,), (2,))
        self.inferencer.insert_final_type(1, "Float List")
        self.inferencer.insert_final_type(2, "Integer")

        with self.assertRaises(CannotInferenceError):
            self.inferencer.inference()
\ No newline at end of file
diff --git a/release/scripts/startup/function_nodes/tests.py b/release/scripts/startup/function_nodes/tests.py
new file mode 100644
index 00000000000..ac42fa38979
--- /dev/null
+++ b/release/scripts/startup/function_nodes/tests.py
@@ -0,0 +1,6 @@
+import unittest
+
def register():
    """Discover and run this package's ``test*`` modules with unittest.

    Results are printed by ``unittest.TextTestRunner``; failures are reported,
    not raised.
    """
    import os

    # Discover from the package's own directory instead of the bare name
    # "function_nodes". unittest resolves a bare name against the current
    # working directory first, so a same-named folder there would make the
    # tests load as top-level modules and break their relative imports.
    # top_level_dir is the package's parent, so test modules are imported as
    # "<package>.test_*".
    # NOTE(review): assumes this package is a regular package with an
    # __init__.py; confirm.
    package_dir = os.path.dirname(os.path.abspath(__file__))
    loader = unittest.TestLoader()
    tests = loader.discover(
        package_dir,
        pattern="test*",
        top_level_dir=os.path.dirname(package_dir))
    unittest.TextTestRunner(verbosity=1).run(tests)
\ No newline at end of file



More information about the Bf-blender-cvs mailing list