[Bf-blender-cvs] [9964eed] master: PyAPI: add optional filter argument to KDTree.find

Campbell Barton noreply at git.blender.org
Sun Dec 6 11:46:49 CET 2015


Commit: 9964eed9ac7547db4c58bf5eabb786440236b138
Author: Campbell Barton
Date:   Sun Dec 6 21:33:39 2015 +1100
Branches: master
https://developer.blender.org/rB9964eed9ac7547db4c58bf5eabb786440236b138

PyAPI: add optional filter argument to KDTree.find

===================================================================

M	source/blender/python/mathutils/mathutils_kdtree.c
M	tests/python/bl_pyapi_mathutils.py

===================================================================

diff --git a/source/blender/python/mathutils/mathutils_kdtree.c b/source/blender/python/mathutils/mathutils_kdtree.c
index dc1e82a..ca66c19 100644
--- a/source/blender/python/mathutils/mathutils_kdtree.c
+++ b/source/blender/python/mathutils/mathutils_kdtree.c
@@ -189,26 +189,57 @@ static PyObject *py_kdtree_balance(PyKDTree *self)
 	Py_RETURN_NONE;
 }
 
+/* Context passed to py_find_nearest_cb as user_data. */
+struct PyKDTree_NearestData {
+	PyObject *py_filter;  /* callable, invoked as py_filter(index) */
+	bool is_error;        /* true once a Python exception has been set */
+};
+
+/* Filter callback: returns 1 to use the node, 0 to skip it,
+ * or -1 with a Python exception set (and is_error flagged) on failure. */
+static int py_find_nearest_cb(void *user_data, int index, const float co[3], float dist_sq)
+{
+	UNUSED_VARS(co, dist_sq);
+	struct PyKDTree_NearestData *data = user_data;
+
+	/* Builds the (index,) argument tuple itself; NULL (exception set) on any failure. */
+	PyObject *result = PyObject_CallFunction(data->py_filter, "i", index);
+
+	if (result) {
+		bool use_node;
+		int ok = PyC_ParseBool(result, &use_node);
+		Py_DECREF(result);
+		if (ok) {
+			return (int)use_node;
+		}
+	}
+
+	data->is_error = true;
+	return -1;
+}
+
 PyDoc_STRVAR(py_kdtree_find_doc,
-".. method:: find(co)\n"
+".. method:: find(co, filter=None)\n"
 "\n"
 "   Find nearest point to ``co``.\n"
 "\n"
 "   :arg co: 3d coordinates.\n"
 "   :type co: float triplet\n"
+"   :arg filter: function which takes an index and returns True for indices to include in the search (must be callable when given; omit it to search all points).\n"
+"   :type filter: callable\n"
 "   :return: Returns (:class:`Vector`, index, distance).\n"
 "   :rtype: :class:`tuple`\n"
 );
 static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs)
 {
-	PyObject *py_co;
+	PyObject *py_co, *py_filter = NULL;
 	float co[3];
 	KDTreeNearest nearest;
-	const char *keywords[] = {"co", NULL};
+	const char *keywords[] = {"co", "filter", NULL};
 
 	if (!PyArg_ParseTupleAndKeywords(
-	        args, kwargs, (char *) "O:find", (char **)keywords,
-	        &py_co))
+	        args, kwargs, (char *) "O|O:find", (char **)keywords,
+	        &py_co, &py_filter))
 	{
 		return NULL;
 	}
@@ -221,10 +252,26 @@ static PyObject *py_kdtree_find(PyKDTree *self, PyObject *args, PyObject *kwargs
 		return NULL;
 	}
 
-
 	nearest.index = -1;
 
-	BLI_kdtree_find_nearest(self->obj, co, &nearest);
+	if (py_filter == NULL) {
+		BLI_kdtree_find_nearest(self->obj, co, &nearest);
+	}
+	else {
+		struct PyKDTree_NearestData data = {0};
+
+		data.py_filter = py_filter;
+		data.is_error = false;
+
+		BLI_kdtree_find_nearest_cb(
+		        self->obj, co,
+		        py_find_nearest_cb, &data,
+		        &nearest);
+
+		if (data.is_error) {
+			return NULL;
+		}
+	}
 
 	return kdtree_nearest_to_py_and_check(&nearest);
 }
diff --git a/tests/python/bl_pyapi_mathutils.py b/tests/python/bl_pyapi_mathutils.py
index b7f61df..7761b6c 100644
--- a/tests/python/bl_pyapi_mathutils.py
+++ b/tests/python/bl_pyapi_mathutils.py
@@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase):
 
 
 class KDTreeTesting(unittest.TestCase):
-
     @staticmethod
-    def kdtree_create_grid_3d(tot):
-        k = kdtree.KDTree(tot * tot * tot)
+    def kdtree_create_grid_3d_data(tot):  # generator over a tot**3 grid; requires tot >= 2 (mul divides by tot - 1)
         index = 0
         mul = 1.0 / (tot - 1)
         for x in range(tot):
             for y in range(tot):
                 for z in range(tot):
-                    k.insert((x * mul, y * mul, z * mul), index)
+                    yield (x * mul, y * mul, z * mul), index  # ((x, y, z), index), each axis spanning [0, 1]
                     index += 1
+
+    @staticmethod
+    def kdtree_create_grid_3d(tot, *, filter_fn=None):  # balanced tree of the grid points accepted by filter_fn(co, index)
+        k = kdtree.KDTree(tot * tot * tot)  # sized for the full grid, even if filter_fn skips points
+        for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot):
+            if (filter_fn is not None) and (not filter_fn(co, index)):  # keep only accepted points
+                continue
+            k.insert(co, index)
         k.balance()
         return k
 
@@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase):
         ret = k.find_n((1.0,) * 3, tot)
         self.assertEqual(len(ret), tot)
 
+    def test_kdtree_grid_filter_simple(self):  # find() with a filter accepting a single index
+        size = 10
+        k = self.kdtree_create_grid_3d(size)
+
+        # accepting only the unfiltered result's index must reproduce that result
+        ret_regular = k.find((1.0,) * 3)
+        ret_filter = k.find((1.0,) * 3, filter=lambda i: i == ret_regular[1])
+        self.assertEqual(ret_regular, ret_filter)
+        ret_filter = k.find((-1.0,) * 3, filter=lambda i: i == ret_regular[1])  # query far from that point
+        self.assertEqual(ret_regular[:2], ret_filter[:2])  # same co and index; only the distance differs
+
+    def test_kdtree_grid_filter_pairs(self):  # filtering k_all by parity must match a tree built from that parity only
+        size = 10
+        k_all = self.kdtree_create_grid_3d(size)
+        k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1)
+        k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0)
+
+        samples = 5  # samples**3 query points over [0, 1] on each axis
+        mul = 1 / (samples - 1)
+        for x in range(samples):
+            for y in range(samples):
+                for z in range(samples):
+                    co = (x * mul, y * mul, z * mul)
+
+                    ret_regular = k_odd.find(co)
+                    self.assertEqual(ret_regular[1] % 2, 1)
+                    ret_filter = k_all.find(co, lambda i: (i % 2) == 1)  # filter passed positionally
+                    self.assertEqual(ret_regular, ret_filter)
+
+                    ret_regular = k_evn.find(co)
+                    self.assertEqual(ret_regular[1] % 2, 0)
+                    ret_filter = k_all.find(co, lambda i: (i % 2) == 0)  # filter passed positionally
+                    self.assertEqual(ret_regular, ret_filter)
+
+        # filter out all values: when the filter rejects every stored index, no index is found
+        # (search odd tree for even values and the reverse)
+        co = (0,) * 3
+        ret_filter = k_odd.find(co, lambda i: (i % 2) == 0)
+        self.assertEqual(ret_filter[1], None)
+
+        ret_filter = k_evn.find(co, lambda i: (i % 2) == 1)
+        self.assertEqual(ret_filter[1], None)
+
     def test_kdtree_invalid_size(self):
         with self.assertRaises(ValueError):
             kdtree.KDTree(-1)
@@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase):
         with self.assertRaises(RuntimeError):
             k.find(co)
 
+    def test_kdtree_invalid_filter(self):  # errors raised from the filter must propagate out of find()
+        k = kdtree.KDTree(1)
+        k.insert((0,) * 3, 0)
+        k.balance()  # non-empty tree, so the search has a point to pass to the filter
+        # not callable: an explicit filter=None is an error, not "no filter"
+        with self.assertRaises(TypeError):
+            k.find((0,) * 3, filter=None)
+        # no args: the callable cannot accept the index argument
+        with self.assertRaises(TypeError):
+            k.find((0,) * 3, filter=lambda: None)
+        # bad return value: None is rejected rather than treated as False
+        with self.assertRaises(ValueError):
+            k.find((0,) * 3, filter=lambda i: None)
+
+
 if __name__ == '__main__':
     import sys
     sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else [])




More information about the Bf-blender-cvs mailing list