Source code for gunshotmatch_pipeline.decision_tree.export

#!/usr/bin/env python3
#
#  export.py
"""
Export and load decision trees to/from JSON-safe dictionaries..

.. versionadded:: 0.6.0
"""
#
#  Copyright © 2023 Dominic Davis-Foster <dominic@davis-foster.co.uk>
#
#  Based on https://github.com/mlrequest/sklearn-json
#  Copyright (c) 2019 Mathieu Rodrigue
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
#  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
#  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
#  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
#  DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
#  OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
#  OR OTHER DEALINGS IN THE SOFTWARE.
#

# stdlib
from typing import Any, Dict, Tuple

# 3rd party
import numpy
from sklearn.ensemble import RandomForestClassifier  # type: ignore[import-untyped]
from sklearn.tree import DecisionTreeClassifier  # type: ignore[import-untyped]
from sklearn.tree._tree import Tree  # type: ignore[import-untyped]

__all__ = [
		"serialise_decision_tree",
		"deserialise_decision_tree",
		"verify_saved_decision_tree",
		"serialise_random_forest",
		"deserialise_random_forest",
		"verify_saved_random_forest",
		]


def _serialise_tree(tree: Tree) -> Tuple[Dict[str, Any], numpy.dtype]:
	serialised_tree = tree.__getstate__()

	dtypes = serialised_tree["nodes"].dtype
	serialised_tree["nodes"] = serialised_tree["nodes"].tolist()
	serialised_tree["values"] = serialised_tree["values"].tolist()

	return serialised_tree, dtypes


def _deserialise_tree(tree_dict: Dict[str, Any], n_features: int, n_classes: int, n_outputs: int) -> Tree:
	tree_dict["nodes"] = [tuple(lst) for lst in tree_dict["nodes"]]

	names = [
			"left_child",
			"right_child",
			"feature",
			"threshold",
			"impurity",
			"n_node_samples",
			"weighted_n_node_samples",
			"missing_go_to_left",
			]
	tree_dict["nodes"] = numpy.array(
			tree_dict["nodes"],
			dtype=numpy.dtype({"names": names, "formats": tree_dict["nodes_dtype"]}),
			)
	tree_dict["values"] = numpy.array(tree_dict["values"])

	tree = Tree(n_features, numpy.array([n_classes], dtype=numpy.intp), n_outputs)
	tree.__setstate__(tree_dict)

	return tree


[docs]def serialise_decision_tree(model: DecisionTreeClassifier) -> Dict[str, Any]: """ Serialise a decision tree to a JSON-safe dictionary. :param model: Trained decision tree. """ tree, dtypes = _serialise_tree(model.tree_) serialised_model = { "classes_": model.classes_.tolist(), "feature_importances_": model.feature_importances_.tolist(), "max_features_": model.max_features_, "n_classes_": int(model.n_classes_), "n_features_in_": model.n_features_in_, "n_outputs_": model.n_outputs_, "tree_": tree, "params": model.get_params(), } if hasattr(model, "feature_names_in_"): serialised_model["feature_names_in_"] = model.feature_names_in_.tolist() tree_dtypes = [] for i in range(0, len(dtypes)): # type: ignore[arg-type] tree_dtypes.append(dtypes[i].str) serialised_model["tree_"]["nodes_dtype"] = tree_dtypes return serialised_model
[docs]def deserialise_decision_tree(model_dict: Dict[str, Any]) -> DecisionTreeClassifier: """ Deserialise a decision tree. :param model_dict: JSON-safe representation of the decision tree. """ deserialised_model = DecisionTreeClassifier(**model_dict["params"]) deserialised_model.classes_ = numpy.array(model_dict["classes_"]) deserialised_model.max_features_ = model_dict["max_features_"] deserialised_model.n_classes_ = model_dict["n_classes_"] deserialised_model.n_features_in_ = model_dict["n_features_in_"] deserialised_model.n_outputs_ = model_dict["n_outputs_"] if "feature_names_in_" in model_dict: deserialised_model.feature_names_in_ = model_dict["feature_names_in_"] tree = _deserialise_tree( model_dict["tree_"], model_dict["n_features_in_"], model_dict["n_classes_"], model_dict["n_outputs_"], ) deserialised_model.tree_ = tree return deserialised_model
[docs]def serialise_random_forest(model: RandomForestClassifier) -> Dict[str, Any]: """ Serialise a random forest to a JSON-safe dictionary. :param model: Trained random forest. """ serialised_model = { "max_depth": model.max_depth, "min_samples_split": model.min_samples_split, "min_samples_leaf": model.min_samples_leaf, "min_weight_fraction_leaf": model.min_weight_fraction_leaf, "max_features": model.max_features, "max_leaf_nodes": model.max_leaf_nodes, "min_impurity_decrease": model.min_impurity_decrease, "n_features_in_": model.n_features_in_, "n_outputs_": model.n_outputs_, "classes_": model.classes_.tolist(), "estimators_": [serialise_decision_tree(decision_tree) for decision_tree in model.estimators_], "params": model.get_params(), } if hasattr(model, "oob_score_"): serialised_model["oob_score_"] = model.oob_score_ if hasattr(model, "oob_decision_function_"): serialised_model["oob_decision_function_"] = model.oob_decision_function_.tolist() if hasattr(model, "feature_names_in_"): serialised_model["feature_names_in_"] = model.feature_names_in_.tolist() if isinstance(model.n_classes_, int): serialised_model["n_classes_"] = model.n_classes_ else: serialised_model["n_classes_"] = model.n_classes_.tolist() return serialised_model
[docs]def deserialise_random_forest(model_dict: Dict[str, Any]) -> RandomForestClassifier: """ Deserialise a random forest. :param model_dict: JSON-safe representation of the random forest. """ model = RandomForestClassifier(**model_dict["params"]) estimators = [deserialise_decision_tree(decision_tree) for decision_tree in model_dict["estimators_"]] model.estimators_ = numpy.array(estimators) model.classes_ = numpy.array(model_dict["classes_"]) model.n_features_in_ = model_dict["n_features_in_"] model.n_outputs_ = model_dict["n_outputs_"] model.max_depth = model_dict["max_depth"] model.min_samples_split = model_dict["min_samples_split"] model.min_samples_leaf = model_dict["min_samples_leaf"] model.min_weight_fraction_leaf = model_dict["min_weight_fraction_leaf"] model.max_features = model_dict["max_features"] model.max_leaf_nodes = model_dict["max_leaf_nodes"] model.min_impurity_decrease = model_dict["min_impurity_decrease"] if "oob_score_" in model_dict: model.oob_score_ = model_dict["oob_score_"] if "oob_decision_function_" in model_dict: model.oob_decision_function_ = model_dict["oob_decision_function_"] if "feature_names_in_" in model_dict: model.feature_names_in_ = model_dict["feature_names_in_"] if isinstance(model_dict["n_classes_"], list): model.n_classes_ = numpy.array(model_dict["n_classes_"]) else: model.n_classes_ = model_dict["n_classes_"] return model
[docs]def verify_saved_decision_tree( in_process: DecisionTreeClassifier, from_file: DecisionTreeClassifier, ) -> None: """ Verify the saved :class:`~sklearn.tree.DecisionTreeClassifier` matches the model in memory. Will raise an :exc:`AssertionError` if the data do not match. :param in_process: The :class:`~sklearn.tree.DecisionTreeClassifier` already in memory. :param from_file: A :class:`~sklearn.tree.DecisionTreeClassifier` loaded from disk. :rtype: .. versionadded:: 0.7.0 .. latex:clearpage:: """ a, b = in_process, from_file assert a.ccp_alpha == b.ccp_alpha, (a.ccp_alpha, b.ccp_alpha) assert a.class_weight == b.class_weight, (a.class_weight, b.class_weight) assert numpy.array_equal(a.classes_, b.classes_), (a.classes_, b.classes_) assert a.criterion == b.criterion, (a.criterion, b.criterion) assert a.max_depth == b.max_depth, (a.max_depth, b.max_depth) assert a.max_features == b.max_features, (a.max_features, b.max_features) assert a.max_features_ == b.max_features_, (a.max_features_, b.max_features_) assert a.max_leaf_nodes == b.max_leaf_nodes, (a.max_leaf_nodes, b.max_leaf_nodes) assert a.min_impurity_decrease == b.min_impurity_decrease, (a.min_impurity_decrease, b.min_impurity_decrease) assert a.min_samples_leaf == b.min_samples_leaf, (a.min_samples_leaf, b.min_samples_leaf) assert a.min_samples_split == b.min_samples_split, (a.min_samples_split, b.min_samples_split) assert a.min_weight_fraction_leaf == b.min_weight_fraction_leaf, ( a.min_weight_fraction_leaf, b.min_weight_fraction_leaf, ) assert a.n_classes_ == b.n_classes_, (a.n_classes_, b.n_classes_) assert a.n_features_in_ == b.n_features_in_, (a.n_features_in_, b.n_features_in_) assert a.n_outputs_ == b.n_outputs_, (a.n_outputs_, b.n_outputs_) assert a.random_state == b.random_state, (a.random_state, b.random_state) assert a.splitter == b.splitter, (a.splitter, b.splitter) a_tree = a.tree_.__getstate__() b_tree = b.tree_.__getstate__() assert a_tree["max_depth"] == b_tree["max_depth"], (a_tree["max_depth"], b_tree["max_depth"]) assert a_tree["node_count"] == b_tree["node_count"], (a_tree["node_count"], b_tree["node_count"]) assert numpy.array_equal(a_tree["nodes"], b_tree["nodes"]), (a_tree["nodes"], b_tree["nodes"])
[docs]def verify_saved_random_forest( in_process: RandomForestClassifier, from_file: RandomForestClassifier, ) -> None: """ Verify the saved :class:`~sklearn.ensemble.RandomForestClassifier` matches the model in memory. Will raise an :exc:`AssertionError` if the data do not match. :param in_process: The :class:`~sklearn.ensemble.RandomForestClassifier` already in memory. :param from_file: A :class:`~sklearn.ensemble.RandomForestClassifier` loaded from disk. :rtype: .. versionadded:: 0.7.0 """ a, b = in_process, from_file assert a.max_depth == b.max_depth, (a.max_depth, b.max_depth) assert a.max_features == b.max_features, (a.max_features, b.max_features) assert a.max_leaf_nodes == b.max_leaf_nodes, (a.max_leaf_nodes, b.max_leaf_nodes) assert a.max_samples == b.max_samples, (a.max_samples, b.max_samples) assert a.min_impurity_decrease == b.min_impurity_decrease, (a.min_impurity_decrease, b.min_impurity_decrease) assert a.min_samples_leaf == b.min_samples_leaf, (a.min_samples_leaf, b.min_samples_leaf) assert a.min_samples_split == b.min_samples_split, (a.min_samples_split, b.min_samples_split) assert a.min_weight_fraction_leaf == b.min_weight_fraction_leaf, ( a.min_weight_fraction_leaf, b.min_weight_fraction_leaf, ) assert a.n_classes_ == b.n_classes_, (a.n_classes_, b.n_classes_) assert a.n_estimators == b.n_estimators, (a.n_estimators, b.n_estimators) assert a.n_features_in_ == b.n_features_in_, (a.n_features_in_, b.n_features_in_) assert a.n_jobs == b.n_jobs, (a.n_jobs, b.n_jobs) assert a.n_outputs_ == b.n_outputs_, (a.n_outputs_, b.n_outputs_) assert a.oob_score == b.oob_score, (a.oob_score, b.oob_score) assert a.random_state == b.random_state, (a.random_state, b.random_state) assert a.verbose == b.verbose, (a.verbose, b.verbose) assert a.warm_start == b.warm_start, (a.warm_start, b.warm_start) if hasattr(a, "feature_names_in_") or hasattr(b, "feature_names_in_"): assert a.feature_names_in_ == b.feature_names_in_, (a.feature_names_in_, b.feature_names_in_) for a_tree, b_tree in zip(a.estimators_, b.estimators_): verify_saved_decision_tree(a_tree, b_tree)