自己写一个 NODE/ATTR 的结构

发布时间: 2023-05-21 17:27:05  作者: 方头狮
## python 3.8 以上
from typing import Dict, List, TypeVar, Tuple, Generic, get_args

import json

T = TypeVar("T")  # value type of a field; bound per field via Attr[<type>](...)


# Default ("zero") values for field types.


def get_dft(tp):
    """Return the zero/default value for field type *tp*.

    Resolution order:

    1. ``tp._user_default_`` when the type defines it. Custom classes opt in
       this way, and it wins even for subclasses of ``str``/``int``/``float``.
    2. ``False`` for ``bool``. This is checked before ``int`` because ``bool``
       is an ``int`` subclass.
    3. ``""`` for ``str``, ``0`` for ``int``, ``0.0`` for ``float`` (and their
       subclasses).
    4. ``None`` for anything else, including arguments that are not classes
       (e.g. ``None`` or typing constructs such as ``List[int]``).
    """
    if hasattr(tp, "_user_default_"):
        return getattr(tp, "_user_default_")
    try:
        for base, value in ((bool, False), (str, ""), (int, 0), (float, 0.0)):
            if issubclass(tp, base):
                return value
    except TypeError:
        # issubclass() rejects non-class arguments; treat them as "no default".
        return None
    return None


# Field descriptor implementation. Assignments are type-checked against the
# field's concrete type argument, so only plain classes are supported.
class base_field(Generic[T]):
    """Typed data descriptor backing every ``a_*`` node attribute.

    Declare a field as ``a_xxx = Attr[int](default, show)``. The concrete type
    argument is read from ``__orig_class__`` in ``__set_name__`` and used by
    ``__set__`` for an ``isinstance`` check. Values are stored per instance in
    its ``__field_data__`` dict, keyed by field name.

    Args:
        default: value returned while the field is unassigned on an instance;
            when ``None``, ``get_dft(tp)`` is returned instead.
        show: whether ``to_dict()`` includes the field when ``all_show`` is
            false.
    """

    name: str
    tp: type  # concrete type argument, used for the isinstance check

    def __init__(self, default=None, show=True):
        self.name = ""
        # T cannot be resolved here: typing attaches __orig_class__ only after
        # __init__ returns, so __set_name__ fills in tp.
        self.tp = None
        self._default = default
        self.is_show = show

    @staticmethod
    def _field_data(instance) -> dict:
        """Return the instance's ``__field_data__`` dict, creating it if absent."""
        data = getattr(instance, "__field_data__", None)
        if data is None:
            data = {}
            setattr(instance, "__field_data__", data)
        return data

    def __get__(self, instance, owner) -> T:
        """Return the stored value, or the default when never assigned.

        Accessed on the class itself (``instance is None``) it returns the
        descriptor. This mirrors the usual descriptor protocol; the previous
        code crashed trying to ``setattr`` on ``None``.
        """
        if instance is None:
            return self
        data = self._field_data(instance)
        if self.name in data:
            dd = data[self.name]
        elif self._default is not None:
            dd = self._default
        else:
            # Computed on every read rather than cached on the descriptor,
            # so one instance's assignment never changes another's default.
            dd = get_dft(self.tp)
        if dd is None:
            print(
                f":: WARN :: value is None : {instance.__class__.__name__}(). {self.name}"
            )
        return dd

    def __set__(self, instance, value):
        """Type-check *value* against ``tp`` and store it on *instance*.

        Raises:
            TypeError: if *value* is not an instance of ``tp``.
        """
        if not isinstance(value, self.tp):
            raise TypeError(f"应该填入 {self.tp},实填{type(value)}:{str(value)[:30]}")
        self._field_data(instance)[self.name] = value

    def __set_name__(self, owner, name):
        """Bind the field to *name* on *owner* and resolve its type argument.

        The field is registered in ``owner.__field_setting__``; node_meta
        pre-seeds that dict with the fields inherited from the bases.

        Raises (wrapped in RuntimeError by Python < 3.12):
            NameError: if this field object is already bound, *name* does not
                start with ``a_``, or *owner* already has a field *name*.
            TypeError: if the field was created without a type argument,
                e.g. ``Attr(0)`` instead of ``Attr[int](0)``.
        """
        if self.name != "":
            raise NameError(f"-> 字段已经有名称,不能重复命名 :now <{self.name}> , {name} ")
        if not str(name).startswith("a_"):  # fields must be prefixed with a_
            raise NameError(f"字段强制用 a_ 开头: {name}")
        fields = owner.__dict__.get("__field_setting__")
        if fields is None:
            # Owner was not prepared by node_meta: give it its own registry
            # (seeded from any inherited one) instead of mutating a base's.
            fields = dict(getattr(owner, "__field_setting__", {}))
            setattr(owner, "__field_setting__", fields)
        if name in fields:
            raise NameError(f"只能有一个同名字段.{name}")
        orig_class = getattr(self, "__orig_class__", None)
        if orig_class is None:
            raise TypeError(f"字段需要指定类型, 例如 Attr[int](...): {name}")
        self.name = name
        self.tp = get_args(orig_class)[0]
        fields[name] = self


class Attr(base_field[T]):
    """Public field type for declaring node attributes.

    Usage inside a node class body::

        a_xxx = Attr[<type>](default=None, show=True)

    All behaviour is inherited from ``base_field``.
    """


class node_meta(type):
    """Metaclass controlling what happens when a NODE class is created.

    For every class it builds, it:

    * gives the class its own ``__field_setting__`` registry, seeded with the
      bases' fields, so fields declared in a subclass never leak into a parent;
    * registers the class by ``__name__`` in ``_class_list``, which
      ``from_dict`` uses to resolve ``_node_type_``;
    * gives the class its own ``__field_default__`` flag dict, with every flag
      starting as ``False``;
    * calls the class's ``__on_sub_class__`` hook.
    """

    # class __name__ -> class. A later class with the same name replaces an
    # earlier one.
    _class_list = {}
    # Fallback hook for classes that do not define __on_sub_class__.
    _fn_temp = lambda: print("没找到函数名称")

    def __new__(cls, name, bases, attrs):
        # Fresh field registry seeded from the bases. A base without a
        # registry (e.g. a plain mixin) simply contributes nothing.
        fields = {}
        for base in bases:
            fields.update(getattr(base, "__field_setting__", {}))
        attrs["__field_setting__"] = fields

        obj_tp = super().__new__(cls, name, bases, attrs)

        node_meta._class_list[name] = obj_tp

        # Per-class flag dict. getattr() would return an inherited dict,
        # making every class share (and mutate) one object.
        flags = obj_tp.__dict__.get("__field_default__")
        if flags is None:
            flags = {}
            for base in bases:
                for key in getattr(base, "__field_default__", {}):
                    flags[key] = False
            setattr(obj_tp, "__field_default__", flags)
        for key in attrs:
            if key.startswith("a_"):
                flags[key] = False
        # `fields` now also holds this class's own fields, which were
        # registered by __set_name__ during type.__new__.
        for key in fields:
            flags.setdefault(key, False)

        # Run the subclass-initialisation hook.
        getattr(obj_tp, "__on_sub_class__", cls._fn_temp)()

        return obj_tp


class node_base(metaclass=node_meta):
    """Common base of every NODE class.

    Subclasses implement ``to_dict``/``from_dict``. This base provides the
    type-dispatching ``from_dict`` entry point, plus ``str``/``repr`` built
    on ``to_dict``.
    """

    @classmethod
    def __on_sub_class__(cls):
        """Hook called by node_meta after each class is created; no-op here."""
        pass

    @classmethod
    def __all_sub_classes__(cls):
        """Return node_meta's registry (class name -> class).

        This is the shared registry itself and is not filtered by ``cls``.
        """
        return node_meta._class_list

    def to_dict(self, all_show=False):
        """Export to the list/dict transport form; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def from_dict(cls, d: "List") -> "node_base":
        """Rebuild a node from ``to_dict`` output, dispatching on ``_node_type_``.

        ``d[0]`` must be a dict whose ``_node_type_`` names a class registered
        in node_meta; that class's own ``from_dict`` does the work.

        Raises:
            AssertionError: malformed *d* (checked with ``assert``, so these
                checks are skipped under ``python -O``).
            KeyError: ``_node_type_`` is not a registered class name.
            TypeError: the named class does not override ``from_dict``;
                delegating to it would recurse forever.
        """
        assert isinstance(d, List)
        assert isinstance(d[0], Dict)

        header: Dict = d[0]
        assert header.get("_node_type_", None) is not None
        node_type = node_meta._class_list[header["_node_type_"]]
        impl = getattr(node_type.from_dict, "__func__", None)
        if impl is node_base.from_dict.__func__:
            raise TypeError(f"{node_type.__name__} 没有实现 from_dict")
        return node_type.from_dict(d)

    def __str__(self):
        return str(self.to_dict())

    __repr__ = __str__


class Node_desc(node_base):
    """Descriptor node attached to each child of a ``Node``; it has no children.

    ``a_count`` is set and updated by ``Node.append``/``Node.pop``. It counts
    how many times a child with the same ``a_name`` was appended.
    """

    a_count = Attr[int](0)

    def to_dict(self, all_show=False):
        """Export as ``[{"_node_type_": <class name>, "_attrs": {...}}]``.

        Fields declared with ``show=False`` are omitted unless *all_show*.
        """
        r = [
            {
                "_node_type_": self.__class__.__name__,
                "_attrs": {
                    k: getattr(self, k)
                    for k, v in self.__field_setting__.items()
                    if all_show or v.is_show
                },
            }
        ]
        return r

    @classmethod
    def from_dict(cls, d: List) -> node_base:
        """Rebuild a descriptor from ``to_dict`` output.

        Only fields declared on the target class are restored. Because this
        data comes from transport, arbitrary keys must not become attributes.

        Raises:
            TypeError: ``_node_type_`` is not ``cls`` or a subclass of it, or
                ``_attrs`` names an undeclared field.
        """
        assert isinstance(d, List)
        assert isinstance(d[0], Dict)

        t1: Dict = d[0]
        assert t1.get("_node_type_", None) is not None
        ndtp: type = node_meta._class_list[t1["_node_type_"]]
        if not issubclass(ndtp, cls):
            raise TypeError("NODE DESC 数据格式不对")
        ret = ndtp()
        for k, v in t1.get("_attrs", {}).items():
            if k not in ndtp.__field_setting__:
                raise TypeError(f"NODE DESC 数据格式不对: 未知字段 {k}")
            setattr(ret, k, v)
        return ret


class Node(node_base):
    """Regular node: carries attributes and an ordered list of child nodes.

    Children are keyed by ``a_name``. Each child name has exactly one
    descriptor (a ``desc_type`` instance) in ``_desc_dict``, whose ``a_count``
    records how many times that child was appended. ``_names_of_child``
    lists the children's names in order.
    """

    desc_type = Node_desc  # descriptor class created for each child
    a_name = Attr[str]("")

    def __init__(self) -> None:
        # Per-instance copy of the class's field flags, all reset to False.
        self.__field_default__ = dict.fromkeys(type(self).__field_default__, False)

        self.__data: List["Node"] = []
        self._desc_dict: Dict[str, Node_desc] = {}
        self._names_of_child: List[str] = []

    def _refresh_names(self) -> None:
        """Recompute ``_names_of_child`` from the current children."""
        self._names_of_child = [child.a_name for child in self.__data]

    def __getitem__(self, key) -> "Node":
        return self.__data[key]

    def __setitem__(self, key, v: "Node"):
        """Replace the child at index *key* with *v*.

        *v* gets a fresh descriptor with count 1, and the replaced child's
        descriptor is dropped.

        Raises:
            ValueError: another child already uses ``v.a_name``.
        """
        assert isinstance(v, Node)
        old = self.__data[key]
        if v.a_name != old.a_name and v.a_name in self._names_of_child:
            raise ValueError(f"子结点名称重复 <{v.a_name}>")
        self.__data[key] = v
        self._desc_dict.pop(old.a_name, None)
        desc = self.desc_type()
        desc.a_count = 1
        self._desc_dict[v.a_name] = desc
        self._refresh_names()

    def append(self, item: "Node", desc: Node_desc = None):
        """Add *item* as a child.

        Without *desc*: if a child with the same ``a_name`` exists, its
        descriptor's ``a_count`` is incremented instead of adding a duplicate.
        Otherwise *item* is added with a fresh ``desc_type()`` whose count is 1.
        With *desc*: *item*'s name must be new and *desc* must be a
        ``desc_type`` instance; *desc* is stored as-is.
        """
        assert isinstance(item, Node)
        if desc is None:
            if item.a_name in self._names_of_child:
                self._desc_dict[item.a_name].a_count += 1
                return
            desc = self.desc_type()
            desc.a_count = 1
        else:
            assert item.a_name not in self._names_of_child
            assert isinstance(desc, self.desc_type)
        self.__data.append(item)
        self._desc_dict[item.a_name] = desc
        self._refresh_names()

    def remove(self, item: "Node"):
        """Remove the child named ``item.a_name`` and its descriptor.

        Does nothing if no child has that name.
        """
        assert isinstance(item, Node)
        if item.a_name not in self._names_of_child:
            return
        self._desc_dict.pop(item.a_name, None)
        for pos, child in enumerate(self.__data):
            if child.a_name == item.a_name:
                del self.__data[pos]
                break
        # Keep the name index in sync. A stale entry made a later append() of
        # the same name look up a descriptor that no longer exists.
        self._refresh_names()

    def index(self, item: "Node"):
        """Return the position of the child named ``item.a_name``, or -1."""
        assert isinstance(item, Node)
        for pos, child in enumerate(self.__data):
            if child.a_name == item.a_name:
                return pos
        return -1

    def pop(self, index: "Node | int" = -1):
        """Take one occurrence of a child, selected by position or by node name.

        If the child's descriptor count is 1, the child is removed. If the
        count is greater than 1, the count is decremented and the child stays.

        Returns:
            The selected child node.

        Raises:
            TypeError: *index* is neither an int nor a Node.
            ValueError: *index* is a Node whose name is not a child, or the
                descriptor count is below 1.
            IndexError: an int *index* is out of range.
        """
        if isinstance(index, int):
            pos = index
        elif isinstance(index, Node):
            pos = self.index(index)
            if pos < 0:
                # -1 here would otherwise silently select the last child.
                raise ValueError(f"没有这个子结点 <{index.a_name}>")
        else:
            raise TypeError(f"pop 的参数应该是 int 或 Node, 实填 {type(index)}")
        nod = self[pos]
        desc = self._desc_dict[nod.a_name]
        cc = desc.a_count
        if cc == 1:
            self.remove(nod)
        elif cc > 1:
            desc.a_count -= 1
        else:
            raise ValueError(f"无法 POP ,数量过少. <{self.a_name}>")
        return nod

    def __len__(self):
        return len(self.__data)

    def to_dict(self, all_show=False):
        """Export as ``[header, [desc_dict, child_dict], ...]``.

        ``header`` holds ``_node_type_`` and ``_attrs``. Fields declared with
        ``show=False`` are omitted unless *all_show*.
        """
        rr = [
            {
                "_node_type_": self.__class__.__name__,
                "_attrs": {
                    k: getattr(self, k)
                    for k, v in self.__field_setting__.items()
                    if all_show or v.is_show
                },
            }
        ]

        for i in self.__data:
            rr.append(
                [self._desc_dict[i.a_name].to_dict(all_show), i.to_dict(all_show)]
            )

        return rr

    @classmethod
    def from_dict(cls, d: "List"):
        """Rebuild a node tree from ``to_dict`` output.

        ``d[0]`` holds the node's type and attributes. Each following entry is
        ``[descriptor_data, child_data]``.

        Raises:
            TypeError: ``_node_type_`` is not a ``Node`` class, or ``_attrs``
                names an undeclared field.
        """
        assert isinstance(d, List)
        assert isinstance(d[0], Dict)

        t1: Dict = d[0]
        assert t1.get("_node_type_", None) is not None
        ndtp = node_meta._class_list[t1["_node_type_"]]
        if not issubclass(ndtp, Node):
            raise TypeError("NODE 数据格式不对")

        ret = ndtp()
        for k, v in t1.get("_attrs", {}).items():
            # Only declared fields may be restored. Arbitrary keys would
            # otherwise overwrite internals such as the child list.
            if k not in ndtp.__field_setting__:
                raise TypeError(f"NODE 数据格式不对: 未知字段 {k}")
            setattr(ret, k, v)
        dsc_tp = ndtp.desc_type
        for c in d[1:]:
            dsc = dsc_tp.from_dict(c[0])
            c_chd = Node.from_dict(c[1])
            ret.append(c_chd, dsc)
        return ret


def node_from_dict(data):
    """Rebuild a node from its ``to_dict`` form, dispatching on ``_node_type_``."""
    build = node_base.from_dict
    return build(data)


def node_to_json(n: "node_base", all_show: bool = True) -> str:
    """Serialize node *n* to a JSON string.

    Args:
        n: any node (``Node`` or ``Node_desc``) that has a ``to_dict`` method.
        all_show: forwarded to ``to_dict``. The default ``True`` exports every
            field, including those declared with ``show=False``.

    Non-ASCII text is written as-is (``ensure_ascii=False``).
    """
    return json.dumps(n.to_dict(all_show), ensure_ascii=False)


def node_from_json(data: str):
    """Parse a JSON string (as written by ``node_to_json``) back into a node."""
    decoded = json.loads(data)
    return node_from_dict(decoded)


if __name__ == "__main__":
    # Demo: build a small tree, export it, and rebuild it from the export.

    class A(Node):
        a_width = Attr[int](50)

    class B(Node):
        a_deep = Attr[int](65)

    class AA(A):
        a_height = Attr[int](15, False)  # show=False: exported only with all_show

    middle = A()
    top = AA()
    leaf = A()
    print(top.a_width)  # not assigned yet -> declared default 50

    for node, label in ((middle, "a_1"), (top, "a_2"), (leaf, "a_3")):
        node.a_name = label
    middle.a_width = 30
    top.a_width = 20
    top.a_height = 40

    middle.append(leaf)
    top.append(middle)

    exported = top.to_dict(True)

    print(exported)

    rebuilt = node_base.from_dict(exported)
    print(rebuilt.to_dict(True))