文章

53 · 类型系统进阶:Protocol、TypeVar 与泛型编程

#055 · 2026-04-17 · Python

🔗 知识图谱导航:阅读本文前,建议先回顾《07 · 面向对象:封装、继承与系统解耦》里的接口思想,以及《08 · 模块化构建:标准库与自定义包》里的工程组织方式。本文会把“代码能跑”推进到“代码能被工具理解和长期维护”。 NexDo Time · 2026-04-17 · 预计阅读 30 分钟

痛点与架构

基础类型提示能告诉我们变量是 intstrlist[str],但真实项目里经常需要表达更复杂的关系:一个函数接收什么类型就返回什么类型;一个对象只要实现了几个方法就能当数据库驱动;一个缓存类既要约束 key,也要约束 value。

高级类型系统的价值不是让代码变花哨,而是让 IDE、类型检查器和未来维护者更早发现错误。你可以把它理解成工程里的“施工图纸”:运行时不一定依赖它,但团队协作时少不了它。

TypeVar  -> 保留输入输出类型关系
Protocol -> 定义结构化接口,不强制继承
Generic  -> 写可复用、带类型记忆的容器
@overload -> 给同名函数提供多套类型签名
Literal  -> 限定参数只能取固定值

步步为营:核心逻辑自适应拆解

类型系统比较抽象,所以这一篇拆成 9 个小步骤。每一步都先看一段文末源码里的真实片段,再运行一个有输出的小演示,先建立手感,再理解类型设计。

Step 1:用 TypeVar 保留输入和输出之间的类型关系

痛点与机制

普通函数只写 list,工具很难知道列表里到底是什么。TypeVar 像给类型贴了一个临时标签:传入 list[int]first() 就返回 int | None;传入 list[str],就返回 str | None。它让 IDE 和类型检查器能跟着数据一路追踪。

核心源码(逐字来自文末完整源码)

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


def first(items: list[T]) -> T | None:
    """泛型函数:返回列表第一个元素,保留类型信息。"""
    return items[0] if items else None


def batch_process(items: list[T], fn: Callable[[T], V]) -> list[V]:
    """泛型高阶函数:对列表每个元素应用函数,保留输入输出类型关联。"""
    return [fn(item) for item in items]

可运行演示(补齐 Mock 数据与 print 反馈)

from typing import Callable, TypeVar

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


def first(items: list[T]) -> T | None:
    """泛型函数:返回列表第一个元素,保留类型信息。"""
    return items[0] if items else None


def batch_process(items: list[T], fn: Callable[[T], V]) -> list[V]:
    """泛型高阶函数:对列表每个元素应用函数,保留输入输出类型关联。"""
    return [fn(item) for item in items]

numbers = [3, 1, 4]
words = ["task_a", "task_b"]
print("first(numbers):", first(numbers), "类型:", type(first(numbers)).__name__)
print("first(words):", first(words), "类型:", type(first(words)).__name__)
print("平方处理:", batch_process(numbers, lambda x: x * x))
print("转大写:", batch_process(words, str.upper))

Step 2:用 Protocol 定义“会这些方法就能用”的接口契约

痛点与机制

Protocol 是静态版鸭子类型:不要求类继承某个父类,只要方法长得对,就算满足接口。它像插座标准,SQLiteDriver 和 InMemoryDriver 不同厂家生产,但插头形状一致,就都能插上。

核心源码(逐字来自文末完整源码)

@runtime_checkable  # 允许 isinstance() 检查
class DBDriver(Protocol):
    """
    数据库驱动协议:任何实现了这些方法的类都满足此协议,
    无需显式继承——这就是"结构子类型"。
    """

    def execute(self, sql: str, params: tuple = ()) -> Any: ...
    def fetchall(self) -> list[tuple]: ...
    def commit(self) -> None: ...
    def close(self) -> None: ...


class SQLiteDriver:
    """SQLite 实现——没有继承 DBDriver,但满足其协议。"""

    def __init__(self, db_path: str = ":memory:") -> None:
        self._conn = sqlite3.connect(db_path)
        self._cur = self._conn.cursor()

    def execute(self, sql: str, params: tuple = ()) -> "SQLiteDriver":
        self._cur.execute(sql, params)
        return self

    def fetchall(self) -> list[tuple]:
        return self._cur.fetchall()

    def commit(self) -> None:
        self._conn.commit()

    def close(self) -> None:
        self._conn.close()


class InMemoryDriver:
    """内存字典实现——同样满足 DBDriver 协议,用于测试。"""

    def __init__(self) -> None:
        self._store: list[tuple] = []
        self._last_result: list[tuple] = []

    def execute(self, sql: str, params: tuple = ()) -> "InMemoryDriver":
        # 极简模拟:只支持 INSERT
        if sql.strip().upper().startswith("INSERT"):
            self._store.append(params)
        elif sql.strip().upper().startswith("SELECT"):
            self._last_result = self._store[:]
        return self

    def fetchall(self) -> list[tuple]:
        return self._last_result

    def commit(self) -> None:
        return None  # 内存存储无需提交

    def close(self) -> None:
        self._store.clear()

可运行演示(补齐 Mock 数据与 print 反馈)

import sqlite3
from typing import Any, Protocol, runtime_checkable

@runtime_checkable  # 允许 isinstance() 检查
class DBDriver(Protocol):
    """
    数据库驱动协议:任何实现了这些方法的类都满足此协议,
    无需显式继承——这就是"结构子类型"。
    """

    def execute(self, sql: str, params: tuple = ()) -> Any: ...
    def fetchall(self) -> list[tuple]: ...
    def commit(self) -> None: ...
    def close(self) -> None: ...


class SQLiteDriver:
    """SQLite 实现——没有继承 DBDriver,但满足其协议。"""

    def __init__(self, db_path: str = ":memory:") -> None:
        self._conn = sqlite3.connect(db_path)
        self._cur = self._conn.cursor()

    def execute(self, sql: str, params: tuple = ()) -> "SQLiteDriver":
        self._cur.execute(sql, params)
        return self

    def fetchall(self) -> list[tuple]:
        return self._cur.fetchall()

    def commit(self) -> None:
        self._conn.commit()

    def close(self) -> None:
        self._conn.close()


class InMemoryDriver:
    """内存字典实现——同样满足 DBDriver 协议,用于测试。"""

    def __init__(self) -> None:
        self._store: list[tuple] = []
        self._last_result: list[tuple] = []

    def execute(self, sql: str, params: tuple = ()) -> "InMemoryDriver":
        # 极简模拟:只支持 INSERT
        if sql.strip().upper().startswith("INSERT"):
            self._store.append(params)
        elif sql.strip().upper().startswith("SELECT"):
            self._last_result = self._store[:]
        return self

    def fetchall(self) -> list[tuple]:
        return self._last_result

    def commit(self) -> None:
        return None  # 内存存储无需提交

    def close(self) -> None:
        self._store.clear()

for driver_cls in [SQLiteDriver, InMemoryDriver]:
    driver = driver_cls()
    print(type(driver).__name__, "满足 DBDriver:", isinstance(driver, DBDriver))
    driver.close()
print("直觉:不看继承关系,只看方法是否齐全。")

Step 3:用 run_task_pipeline 让数据库实现可以自由替换

痛点与机制

业务函数不应该关心底层是 SQLite 还是内存字典,只关心驱动能不能 execute/fetchall/commit/close。这就是依赖接口而不是依赖实现,测试时可以换成 InMemoryDriver,生产时换成 SQLiteDriver。

核心源码(逐字来自文末完整源码)

def run_task_pipeline(driver: DBDriver, tasks: list[dict]) -> None:
    """
    接受任何满足 DBDriver 协议的对象——不关心具体实现。
    这就是 Protocol 的核心价值:解耦接口与实现。
    """
    driver.execute(
        "CREATE TABLE IF NOT EXISTS tasks "
        "(id INTEGER PRIMARY KEY, name TEXT, status TEXT)"
    )
    for task in tasks:
        driver.execute(
            "INSERT INTO tasks (name, status) VALUES (?, ?)",
            (task["name"], task["status"]),
        )
    driver.commit()
    driver.execute("SELECT * FROM tasks")
    rows = driver.fetchall()
    print(f"\n  通过 {type(driver).__name__} 写入并读取 {len(rows)} 条任务")
    for row in rows:
        print(f"    {row}")

可运行演示(补齐 Mock 数据与 print 反馈)

import sqlite3
from typing import Any, Protocol, runtime_checkable

@runtime_checkable  # 允许 isinstance() 检查
class DBDriver(Protocol):
    """
    数据库驱动协议:任何实现了这些方法的类都满足此协议,
    无需显式继承——这就是"结构子类型"。
    """

    def execute(self, sql: str, params: tuple = ()) -> Any: ...
    def fetchall(self) -> list[tuple]: ...
    def commit(self) -> None: ...
    def close(self) -> None: ...


class SQLiteDriver:
    """SQLite 实现——没有继承 DBDriver,但满足其协议。"""

    def __init__(self, db_path: str = ":memory:") -> None:
        self._conn = sqlite3.connect(db_path)
        self._cur = self._conn.cursor()

    def execute(self, sql: str, params: tuple = ()) -> "SQLiteDriver":
        self._cur.execute(sql, params)
        return self

    def fetchall(self) -> list[tuple]:
        return self._cur.fetchall()

    def commit(self) -> None:
        self._conn.commit()

    def close(self) -> None:
        self._conn.close()


class InMemoryDriver:
    """内存字典实现——同样满足 DBDriver 协议,用于测试。"""

    def __init__(self) -> None:
        self._store: list[tuple] = []
        self._last_result: list[tuple] = []

    def execute(self, sql: str, params: tuple = ()) -> "InMemoryDriver":
        # 极简模拟:只支持 INSERT
        if sql.strip().upper().startswith("INSERT"):
            self._store.append(params)
        elif sql.strip().upper().startswith("SELECT"):
            self._last_result = self._store[:]
        return self

    def fetchall(self) -> list[tuple]:
        return self._last_result

    def commit(self) -> None:
        return None  # 内存存储无需提交

    def close(self) -> None:
        self._store.clear()

def run_task_pipeline(driver: DBDriver, tasks: list[dict]) -> None:
    """
    接受任何满足 DBDriver 协议的对象——不关心具体实现。
    这就是 Protocol 的核心价值:解耦接口与实现。
    """
    driver.execute(
        "CREATE TABLE IF NOT EXISTS tasks "
        "(id INTEGER PRIMARY KEY, name TEXT, status TEXT)"
    )
    for task in tasks:
        driver.execute(
            "INSERT INTO tasks (name, status) VALUES (?, ?)",
            (task["name"], task["status"]),
        )
    driver.commit()
    driver.execute("SELECT * FROM tasks")
    rows = driver.fetchall()
    print(f"\n  通过 {type(driver).__name__} 写入并读取 {len(rows)} 条任务")
    for row in rows:
        print(f"    {row}")

tasks = [
    {"name": "采集日志", "status": "done"},
    {"name": "清洗文本", "status": "running"},
]
for driver_cls in [SQLiteDriver, InMemoryDriver]:
    driver = driver_cls()
    run_task_pipeline(driver, tasks)
    driver.close()

Step 4:用 Stack[T] 写一个不丢类型信息的泛型栈

痛点与机制

栈就像一摞盘子,后放上去的先拿走。Stack[int]Stack[str] 用同一套代码,但类型含义不同。泛型的价值是复用结构,同时保留“里面装的是什么”。

核心源码(逐字来自文末完整源码)

class Stack(Generic[T]):
    """
    泛型栈:Stack[int] 和 Stack[str] 是不同的类型。
    Python 3.12+ 可用 `class Stack[T]:` 语法。
    """

    def __init__(self) -> None:
        self._items: list[T] = []

    def push(self, item: T) -> None:
        self._items.append(item)

    def pop(self) -> T:
        if not self._items:
            raise IndexError("Stack is empty")
        return self._items.pop()

    def peek(self) -> T | None:
        return self._items[-1] if self._items else None

    def __len__(self) -> int:
        return len(self._items)

    def __repr__(self) -> str:
        return f"Stack{self._items}"

可运行演示(补齐 Mock 数据与 print 反馈)

from typing import Generic, TypeVar

T = TypeVar("T")

class Stack(Generic[T]):
    """
    泛型栈:Stack[int] 和 Stack[str] 是不同的类型。
    Python 3.12+ 可用 `class Stack[T]:` 语法。
    """

    def __init__(self) -> None:
        self._items: list[T] = []

    def push(self, item: T) -> None:
        self._items.append(item)

    def pop(self) -> T:
        if not self._items:
            raise IndexError("Stack is empty")
        return self._items.pop()

    def peek(self) -> T | None:
        return self._items[-1] if self._items else None

    def __len__(self) -> int:
        return len(self._items)

    def __repr__(self) -> str:
        return f"Stack{self._items}"

int_stack: Stack[int] = Stack()
for item in [10, 20, 30]:
    int_stack.push(item)
print("整数栈:", int_stack)
print("pop:", int_stack.pop())
print("peek:", int_stack.peek())

str_stack: Stack[str] = Stack()
for item in ["采集", "清洗", "训练"]:
    str_stack.push(item)
print("字符串栈:", str_stack)

Step 5:用 TypedCache[K, V] 同时约束键和值的类型

痛点与机制

缓存像一个小抽屉,键负责定位,值负责存内容。TypedCache[str, dict] 表示键是字符串、值是字典。maxsize 满了就淘汰最老条目,演示里可以直接看到 user:1 被挤出去。

核心源码(逐字来自文末完整源码)

class TypedCache(Generic[K, V]):
    """双类型参数泛型类:键值类型独立指定。"""

    def __init__(self, maxsize: int = 128) -> None:
        self._data: dict[K, V] = {}
        self._maxsize = maxsize

    def set(self, key: K, value: V) -> None:
        if len(self._data) >= self._maxsize:
            oldest = next(iter(self._data))
            del self._data[oldest]
        self._data[key] = value

    def get(self, key: K, default: V | None = None) -> V | None:
        return self._data.get(key, default)

    def __len__(self) -> int:
        return len(self._data)

可运行演示(补齐 Mock 数据与 print 反馈)

from typing import Generic, TypeVar

K = TypeVar("K")
V = TypeVar("V")

class TypedCache(Generic[K, V]):
    """双类型参数泛型类:键值类型独立指定。"""

    def __init__(self, maxsize: int = 128) -> None:
        self._data: dict[K, V] = {}
        self._maxsize = maxsize

    def set(self, key: K, value: V) -> None:
        if len(self._data) >= self._maxsize:
            oldest = next(iter(self._data))
            del self._data[oldest]
        self._data[key] = value

    def get(self, key: K, default: V | None = None) -> V | None:
        return self._data.get(key, default)

    def __len__(self) -> int:
        return len(self._data)

cache: TypedCache[str, dict] = TypedCache(maxsize=2)
cache.set("user:1", {"name": "Alice", "role": "admin"})
cache.set("user:2", {"name": "Bob", "role": "user"})
cache.set("user:3", {"name": "Cindy", "role": "guest"})
print("缓存大小:", len(cache))
print("user:1 已被淘汰:", cache.get("user:1"))
print("user:3:", cache.get("user:3"))

Step 6:用 @overload 让同一个函数拥有多个类型签名

痛点与机制

@overload 像给前台贴说明牌:传一个字符串返回一个整数,传字符串列表返回整数列表。运行时真正执行的还是最后那个 parse_value(),但 IDE 会根据重载签名给出更准的提示。

核心源码(逐字来自文末完整源码)


@overload
def parse_value(x: str) -> int: ...
@overload
def parse_value(x: list[str]) -> list[int]: ...

def parse_value(x: str | list[str]) -> int | list[int]:
    """
    @overload 让 IDE 知道:传入 str 返回 int,传入 list[str] 返回 list[int]。
    运行时只有一个实现。
    """
    if isinstance(x, list):
        return [int(v) for v in x]
    return int(x)

可运行演示(补齐 Mock 数据与 print 反馈)

from typing import overload

# ── @overload:函数重载类型提示 ───────────────────────────────
@overload
def parse_value(x: str) -> int: ...
@overload
def parse_value(x: list[str]) -> list[int]: ...

def parse_value(x: str | list[str]) -> int | list[int]:
    """
    @overload 让 IDE 知道:传入 str 返回 int,传入 list[str] 返回 list[int]。
    运行时只有一个实现。
    """
    if isinstance(x, list):
        return [int(v) for v in x]
    return int(x)

single = parse_value("42")
batch = parse_value(["1", "2", "3"])
print("单个字符串 ->", single, type(single).__name__)
print("字符串列表 ->", batch, type(batch).__name__)
print("直觉:IDE 看重载签名,运行时只走最后那个真实实现。")

Step 7:用 Literal 把参数限制在几个固定选项里

痛点与机制

有些参数不是任意字符串,比如文件模式只能是 r/w/a/rb/wbLiteral 像下拉菜单,提前告诉工具哪些值合法,减少把 readwrite 这种错误字符串传进去的机会。

核心源码(逐字来自文末完整源码)

def open_file(
    path: str,
    mode: Literal["r", "w", "a", "rb", "wb"] = "r",
) -> str:
    """mode 参数只能是指定的几个值,IDE 会提示非法值。"""
    return f"打开文件 {path},模式 {mode}"

可运行演示(补齐 Mock 数据与 print 反馈)

from typing import Literal

def open_file(
    path: str,
    mode: Literal["r", "w", "a", "rb", "wb"] = "r",
) -> str:
    """mode 参数只能是指定的几个值,IDE 会提示非法值。"""
    return f"打开文件 {path},模式 {mode}"

print(open_file("data.csv", "r"))
print(open_file("result.bin", "wb"))
print("Literal 像下拉菜单:建议只让 mode 选固定几个合法值。")

Step 8:用 print_type_table 建立高级类型速查地图

痛点与机制

类型系统概念很多,新手容易混。速查表把 Any、Union、Optional、Literal、TypeVar、Protocol、Generic、overload 放在一起,像一张地图,方便回头定位每个工具解决什么问题。

核心源码(逐字来自文末完整源码)

def print_type_table() -> None:
    print("\n  ── 类型系统层次速查 ──────────────────────")
    rows = [
        ("Any",          "关闭类型检查",          "any_val: Any = ..."),
        ("Union[X,Y]",   "X 或 Y(3.10+ 用 X|Y)","val: int | str"),
        ("Optional[X]",  "X 或 None",             "val: str | None"),
        ("Literal[...]", "限定具体值",             "mode: Literal['r','w']"),
        ("TypeVar",      "泛型占位符",             "T = TypeVar('T')"),
        ("Protocol",     "结构子类型接口",         "@runtime_checkable"),
        ("Generic[T]",   "泛型类",                "class Stack(Generic[T])"),
        ("@overload",    "函数重载签名",           "@overload def f(x:str)->int"),
    ]
    print(f"  {'特性':<16} {'用途':<20} {'示例'}")
    print(f"  {'─'*16} {'─'*20} {'─'*30}")
    for feat, desc, example in rows:
        print(f"  {feat:<16} {desc:<20} {example}")

可运行演示(补齐 Mock 数据与 print 反馈)

def print_type_table() -> None:
    print("\n  ── 类型系统层次速查 ──────────────────────")
    rows = [
        ("Any",          "关闭类型检查",          "any_val: Any = ..."),
        ("Union[X,Y]",   "X 或 Y(3.10+ 用 X|Y)","val: int | str"),
        ("Optional[X]",  "X 或 None",             "val: str | None"),
        ("Literal[...]", "限定具体值",             "mode: Literal['r','w']"),
        ("TypeVar",      "泛型占位符",             "T = TypeVar('T')"),
        ("Protocol",     "结构子类型接口",         "@runtime_checkable"),
        ("Generic[T]",   "泛型类",                "class Stack(Generic[T])"),
        ("@overload",    "函数重载签名",           "@overload def f(x:str)->int"),
    ]
    print(f"  {'特性':<16} {'用途':<20} {'示例'}")
    print(f"  {'─'*16} {'─'*20} {'─'*30}")
    for feat, desc, example in rows:
        print(f"  {feat:<16} {desc:<20} {example}")

print_type_table()

Step 9:用 main 把 table/protocol/generic/overload 做成命令入口

痛点与机制

类型教程也要能一键运行。argparse 让读者用 --mode table--mode protocol--mode generic--mode overload 切换演示,不需要改源码。

核心源码(逐字来自文末完整源码)

def main() -> None:
    parser = argparse.ArgumentParser(description="类型系统进阶演示")
    parser.add_argument(
        "--mode",
        choices=["protocol", "generic", "overload", "table", "all"],
        default="all",
    )
    args = parser.parse_args()

    if args.mode in ("table", "all"):
        print_type_table()
    if args.mode in ("protocol", "all"):
        demo_protocol()
    if args.mode in ("generic", "all"):
        demo_generic()
    if args.mode in ("overload", "all"):
        demo_overload()

可运行演示(补齐 Mock 数据与 print 反馈)

import argparse

def main() -> None:
    parser = argparse.ArgumentParser(description="类型系统进阶演示")
    parser.add_argument(
        "--mode",
        choices=["protocol", "generic", "overload", "table", "all"],
        default="all",
    )
    args = parser.parse_args()

    if args.mode in ("table", "all"):
        print_type_table()
    if args.mode in ("protocol", "all"):
        demo_protocol()
    if args.mode in ("generic", "all"):
        demo_generic()
    if args.mode in ("overload", "all"):
        demo_overload()

def print_type_table() -> None:
    print("打印类型速查表")


def demo_protocol() -> None:
    print("运行 Protocol 演示")


def demo_generic() -> None:
    print("运行 Generic 演示")


def demo_overload() -> None:
    print("运行 overload 演示")

import sys
for mode in ["table", "protocol", "generic", "overload"]:
    print(f"\n$ python typing_advanced.py --mode {mode}")
    sys.argv = ["typing_advanced.py", "--mode", mode]
    main()

极客实战:完整源码与运行

现在,把上面的积木拼起来,将下面完整代码保存为 typing_advanced.py。它会用 SQLite 内存库和内存字典驱动演示 Protocol,用泛型栈和泛型缓存演示 Generic,并用 overload/Literal 展示 IDE 级类型约束。

# typing_advanced.py
"""
类型系统进阶演示 —— Protocol/TypeVar/Generic 工程级用法。
用法:
    python3 typing_advanced.py
    python3 typing_advanced.py --mode protocol
    python3 typing_advanced.py --mode generic
    python3 typing_advanced.py --mode overload
"""

import argparse
import sqlite3
from typing import (
    Any, Callable, Generic, Literal, Protocol,
    TypeVar, overload, runtime_checkable,
)

# ── TypeVar ───────────────────────────────────────────────────
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


def first(items: list[T]) -> T | None:
    """泛型函数:返回列表第一个元素,保留类型信息。"""
    return items[0] if items else None


def batch_process(items: list[T], fn: Callable[[T], V]) -> list[V]:
    """泛型高阶函数:对列表每个元素应用函数,保留输入输出类型关联。"""
    return [fn(item) for item in items]


# ── Protocol:结构子类型(鸭子类型的静态版)─────────────────
@runtime_checkable  # 允许 isinstance() 检查
class DBDriver(Protocol):
    """
    数据库驱动协议:任何实现了这些方法的类都满足此协议,
    无需显式继承——这就是"结构子类型"。
    """

    def execute(self, sql: str, params: tuple = ()) -> Any: ...
    def fetchall(self) -> list[tuple]: ...
    def commit(self) -> None: ...
    def close(self) -> None: ...


class SQLiteDriver:
    """SQLite 实现——没有继承 DBDriver,但满足其协议。"""

    def __init__(self, db_path: str = ":memory:") -> None:
        self._conn = sqlite3.connect(db_path)
        self._cur = self._conn.cursor()

    def execute(self, sql: str, params: tuple = ()) -> "SQLiteDriver":
        self._cur.execute(sql, params)
        return self

    def fetchall(self) -> list[tuple]:
        return self._cur.fetchall()

    def commit(self) -> None:
        self._conn.commit()

    def close(self) -> None:
        self._conn.close()


class InMemoryDriver:
    """内存字典实现——同样满足 DBDriver 协议,用于测试。"""

    def __init__(self) -> None:
        self._store: list[tuple] = []
        self._last_result: list[tuple] = []

    def execute(self, sql: str, params: tuple = ()) -> "InMemoryDriver":
        # 极简模拟:只支持 INSERT
        if sql.strip().upper().startswith("INSERT"):
            self._store.append(params)
        elif sql.strip().upper().startswith("SELECT"):
            self._last_result = self._store[:]
        return self

    def fetchall(self) -> list[tuple]:
        return self._last_result

    def commit(self) -> None:
        return None  # 内存存储无需提交

    def close(self) -> None:
        self._store.clear()


def run_task_pipeline(driver: DBDriver, tasks: list[dict]) -> None:
    """
    接受任何满足 DBDriver 协议的对象——不关心具体实现。
    这就是 Protocol 的核心价值:解耦接口与实现。
    """
    driver.execute(
        "CREATE TABLE IF NOT EXISTS tasks "
        "(id INTEGER PRIMARY KEY, name TEXT, status TEXT)"
    )
    for task in tasks:
        driver.execute(
            "INSERT INTO tasks (name, status) VALUES (?, ?)",
            (task["name"], task["status"]),
        )
    driver.commit()
    driver.execute("SELECT * FROM tasks")
    rows = driver.fetchall()
    print(f"\n  通过 {type(driver).__name__} 写入并读取 {len(rows)} 条任务")
    for row in rows:
        print(f"    {row}")


# ── Generic:泛型类 ───────────────────────────────────────────
class Stack(Generic[T]):
    """
    泛型栈:Stack[int] 和 Stack[str] 是不同的类型。
    Python 3.12+ 可用 `class Stack[T]:` 语法。
    """

    def __init__(self) -> None:
        self._items: list[T] = []

    def push(self, item: T) -> None:
        self._items.append(item)

    def pop(self) -> T:
        if not self._items:
            raise IndexError("Stack is empty")
        return self._items.pop()

    def peek(self) -> T | None:
        return self._items[-1] if self._items else None

    def __len__(self) -> int:
        return len(self._items)

    def __repr__(self) -> str:
        return f"Stack{self._items}"


class TypedCache(Generic[K, V]):
    """双类型参数泛型类:键值类型独立指定。"""

    def __init__(self, maxsize: int = 128) -> None:
        self._data: dict[K, V] = {}
        self._maxsize = maxsize

    def set(self, key: K, value: V) -> None:
        if len(self._data) >= self._maxsize:
            oldest = next(iter(self._data))
            del self._data[oldest]
        self._data[key] = value

    def get(self, key: K, default: V | None = None) -> V | None:
        return self._data.get(key, default)

    def __len__(self) -> int:
        return len(self._data)


# ── @overload:函数重载类型提示 ───────────────────────────────
@overload
def parse_value(x: str) -> int: ...
@overload
def parse_value(x: list[str]) -> list[int]: ...

def parse_value(x: str | list[str]) -> int | list[int]:
    """
    @overload 让 IDE 知道:传入 str 返回 int,传入 list[str] 返回 list[int]。
    运行时只有一个实现。
    """
    if isinstance(x, list):
        return [int(v) for v in x]
    return int(x)


# ── Literal:限定具体值 ───────────────────────────────────────
def open_file(
    path: str,
    mode: Literal["r", "w", "a", "rb", "wb"] = "r",
) -> str:
    """mode 参数只能是指定的几个值,IDE 会提示非法值。"""
    return f"打开文件 {path},模式 {mode}"


# ── 演示函数 ──────────────────────────────────────────────────
def demo_protocol() -> None:
    print("\n  ── Protocol 演示 ─────────────────────────")

    tasks = [
        {"name": "数据采集", "status": "done"},
        {"name": "文本清洗", "status": "running"},
        {"name": "模型推理", "status": "pending"},
    ]

    # 两种驱动都满足 DBDriver 协议
    for driver_cls in [SQLiteDriver, InMemoryDriver]:
        driver = driver_cls()
        run_task_pipeline(driver, tasks)
        driver.close()

    # isinstance 检查(需要 @runtime_checkable)
    sqlite_drv = SQLiteDriver()
    print(f"\n  isinstance(SQLiteDriver(), DBDriver) = {isinstance(sqlite_drv, DBDriver)}")
    sqlite_drv.close()


def demo_generic() -> None:
    print("\n  ── Generic 泛型类演示 ────────────────────")

    # Stack[int]
    int_stack: Stack[int] = Stack()
    for n in [10, 20, 30]:
        int_stack.push(n)
    print(f"  int_stack: {int_stack}")
    print(f"  pop: {int_stack.pop()}, peek: {int_stack.peek()}")

    # Stack[str]
    str_stack: Stack[str] = Stack()
    for s in ["task_a", "task_b", "task_c"]:
        str_stack.push(s)
    print(f"  str_stack: {str_stack}")

    # TypedCache[str, dict]
    cache: TypedCache[str, dict] = TypedCache(maxsize=3)
    cache.set("user:1001", {"name": "Alice", "role": "admin"})
    cache.set("user:1002", {"name": "Bob",   "role": "user"})
    print(f"\n  cache size: {len(cache)}")
    print(f"  cache.get('user:1001'): {cache.get('user:1001')}")


def demo_overload() -> None:
    print("\n  ── @overload 与 Literal 演示 ─────────────")

    # @overload
    single = parse_value("42")
    batch  = parse_value(["1", "2", "3"])
    print(f"  parse_value('42')          = {single!r}  (type: {type(single).__name__})")
    print(f"  parse_value(['1','2','3']) = {batch!r}  (type: {type(batch).__name__})")

    # Literal
    print(f"\n  {open_file('data.csv', 'r')}")
    print(f"  {open_file('output.bin', 'wb')}")

    # TypeVar 演示
    nums: list[int] = [3, 1, 4, 1, 5]
    strs: list[str] = ["task_a", "task_b"]
    print(f"\n  first(nums) = {first(nums)!r}  (type: {type(first(nums)).__name__})")
    print(f"  first(strs) = {first(strs)!r}  (type: {type(first(strs)).__name__})")


def print_type_table() -> None:
    print("\n  ── 类型系统层次速查 ──────────────────────")
    rows = [
        ("Any",          "关闭类型检查",          "any_val: Any = ..."),
        ("Union[X,Y]",   "X 或 Y(3.10+ 用 X|Y)","val: int | str"),
        ("Optional[X]",  "X 或 None",             "val: str | None"),
        ("Literal[...]", "限定具体值",             "mode: Literal['r','w']"),
        ("TypeVar",      "泛型占位符",             "T = TypeVar('T')"),
        ("Protocol",     "结构子类型接口",         "@runtime_checkable"),
        ("Generic[T]",   "泛型类",                "class Stack(Generic[T])"),
        ("@overload",    "函数重载签名",           "@overload def f(x:str)->int"),
    ]
    print(f"  {'特性':<16} {'用途':<20} {'示例'}")
    print(f"  {'─'*16} {'─'*20} {'─'*30}")
    for feat, desc, example in rows:
        print(f"  {feat:<16} {desc:<20} {example}")


def main() -> None:
    parser = argparse.ArgumentParser(description="类型系统进阶演示")
    parser.add_argument(
        "--mode",
        choices=["protocol", "generic", "overload", "table", "all"],
        default="all",
    )
    args = parser.parse_args()

    if args.mode in ("table", "all"):
        print_type_table()
    if args.mode in ("protocol", "all"):
        demo_protocol()
    if args.mode in ("generic", "all"):
        demo_generic()
    if args.mode in ("overload", "all"):
        demo_overload()


if __name__ == "__main__":
    main()
$ python typing_advanced.py --mode table
── 类型系统层次速查 ──────────────────────
  特性               用途                   示例
  ──────────────── ──────────────────── ──────────────────────────────
  Any              关闭类型检查               any_val: Any = ...
  Union[X,Y]       X 或 Y(3.10+ 用 X|Y)   val: int | str
  Optional[X]      X 或 None             val: str | None
  Literal[...]     限定具体值                mode: Literal['r','w']
  TypeVar          泛型占位符                T = TypeVar('T')
  Protocol         结构子类型接口              @runtime_checkable
  Generic[T]       泛型类                  class Stack(Generic[T])
  @overload        函数重载签名               @overload def f(x:str)->int

$ python typing_advanced.py --mode protocol
── Protocol 演示 ─────────────────────────

  通过 SQLiteDriver 写入并读取 3 条任务
    (1, '数据采集', 'done')
    (2, '文本清洗', 'running')
    (3, '模型推理', 'pending')

  通过 InMemoryDriver 写入并读取 3 条任务
    ('数据采集', 'done')
    ('文本清洗', 'running')
    ('模型推理', 'pending')

  isinstance(SQLiteDriver(), DBDriver) = True

$ python typing_advanced.py --mode generic
── Generic 泛型类演示 ────────────────────
  int_stack: Stack[10, 20, 30]
  pop: 30, peek: 20
  str_stack: Stack['task_a', 'task_b', 'task_c']

  cache size: 2
  cache.get('user:1001'): {'name': 'Alice', 'role': 'admin'}

$ python typing_advanced.py --mode overload
── @overload 与 Literal 演示 ─────────────
  parse_value('42')          = 42  (type: int)
  parse_value(['1','2','3']) = [1, 2, 3]  (type: list)

  打开文件 data.csv,模式 r
  打开文件 output.bin,模式 wb

  first(nums) = 3  (type: int)
  first(strs) = 'task_a'  (type: str)

小结与 NexDo Time ⚡

这一篇你掌握了 Python 类型系统的进阶工具:TypeVar 负责追踪类型关系,Protocol 负责定义结构化接口,Generic 负责复用容器并保留元素类型,overload 和 Literal 负责让函数签名更精确。它们不会直接让程序跑得更快,但会让大型项目更容易被检查、重构和协作维护。

5 分钟微操挑战:给 TypedCache 增加一个 delete(key: K) -> bool 方法,删除成功返回 True,key 不存在返回 False。然后在 demo_generic() 里打印删除前后的缓存大小。

Don’t wait for next time, do it in the next moment.