[Python] 基于 flask 构建 Web API 实现参数注入和校验

发布时间：2023-11-03 15:24:46  作者：我爱我家喵喵

在 Python 中，Flask 是一个轻量级的 Web 框架，常用于快速构建 HTTP 服务。

但它本身并没有提供参数校验和注入的功能。习惯了用 Java 等语言开发 Web API 的同学，应该都不想每定义一个 API 就要写很多代码去校验和获取请求参数吧，至少我是这样。

幸运的是，已经有人提供了参数校验相关的包，可以通过装饰器（类似 Java 的注解）的形式满足这些需求。

但我想尝试一种新的方式：简单地扩展 @app.route 装饰器来实现。下面的代码只是初步实现，还有很多需要完善的地方，但对我目前的需求来说已经够用了。

扩展 @app.route 装饰器功能

核心代码如下:

import inspect
import json
import typing as t
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Union, List, Dict, get_origin

from flask import Flask, jsonify, request, url_for, render_template, Response
from flask import typing as ft
from flask_cors import CORS

def init_server(port: int):
    r"""
    启动HTTP服务
    """
    app = HttpServer(__name__)
    print("允许跨域")
    CORS(app, resources={r"/*": {"origins": "*"}})
    # 解决jsonify中文乱码
    app.config['JSON_AS_ASCII'] = False

    @app.errorhandler(Exception)
    def special_exception_handler(error):
        print('服务器错误', error)
        return jsonify(R.error(str(error), None, 500).json()), 500

    @app.errorhandler(RequestArgsError)
    def special_exception_args_handler(error):
        return jsonify(R.error(str(error), None, 500).json()), 500

    # 拦截器
    @app.before_request
    def req_before():
        if request.path in app.rule_map:
            # token校验等操作
            # 自动参数注入
            rule = app.rule_map[request.path.lower()]
            args: list[any] = []
            isPost = request.method == 'POST'

            body: Dict[str, any] = request.get_json() if isPost else {}
            if (isPost and body is None):
                body = request.form
            print("body", body)

            for k, parameter in rule.parameters:
                if not rule.args is None and k.lower() in rule.args:
                    k = rule.args[k.lower()]
                value = None
                name = k

                if not isPost:
                    value = request.args.get(name)
                elif isPost:
                    value = body.get(name)
                    if (value is None):
                        value = request.args.get(name)

                if (value is None or (type(value) is str and value == "")):
                    value = None if parameter.default is inspect.Parameter.empty else parameter.default
                    if (rule.require.get(name, False) == True):
                        raise RequestArgsError('参数【' + name + '】不能为空')
                else:
                    try:
                        if (parameter.annotation == int):
                            value = int(value)
                        elif (parameter.annotation == str):
                            value = str(value)
                        elif (parameter.annotation == Dict[str, any] or parameter.annotation == Dict[str, str]):
                            value = json.loads(str(value))
                        else:
                            origin = get_origin(parameter.annotation)
                            if (origin is dict):
                                if (type(value) is str):
                                    value = json.loads(str(value))
                            else:
                                print('parameter origin', origin, origin is dict, value, type(value).__name__)
                    except Exception as e:
                        raise RequestArgsError('参数【' + name + '】非法:' + str(value))
                args.append(value)
                print(k, parameter.annotation, type(value).__name__, value)
            # func = getattr(rule.owner, rule.func.__name__)
            resp = rule.func(*args)
            print('resp', resp, resp is object)
            if (not resp is None and isinstance(resp, R)):
                resp = resp.json()
                return json.dumps(resp, cls=DateEncoder)
            return resp

    # 初始化API控制器
    controller = APIController(app)

    if __name__ == '__main__':
        print("开启HTTP服务 (port: " + str(port) + ")")
        app.run('0.0.0.0', port)

日期格式化、异常类、响应类声明:

class DateEncoder(json.JSONEncoder):
    r"""
    JSON encoder that writes ``datetime`` values as ``YYYY-MM-DD HH:MM:SS``
    strings and defers every other type to the standard encoder.
    """

    def default(self, obj):
        if not isinstance(obj, datetime):
            # Unsupported types raise TypeError from the base implementation.
            return super().default(obj)
        return obj.strftime("%Y-%m-%d %H:%M:%S")

class RequestArgsError(OSError):
    r"""
    Raised when a request parameter is missing or cannot be converted.

    NOTE(review): the OSError base is unusual for a validation error; it is
    kept so existing ``except OSError`` handlers keep working.
    """

class R:
    code: int
    msg: str | None
    data: any

    def __init__(self, code: int, data: any, msg: str | None = None):
        self.code = code
        self.data = data
        self.msg = msg

    def json(self):
        return {
            "code": self.code,
            "msg": self.msg,
            "data": self.data
        }

    @staticmethod
    def success(data: any = None):
        return R(0, data)

    @staticmethod
    def error(msg: str | None = None, data: any = None, code: int = -1):
        return R(code, data, msg)

路由规则配置:

class RuleConfg:
    func: any
    parameters: any
    args: Dict[str, str] = {}
    require: Dict[str, bool] = {}

    def __init__(self, view_func: any, args: Dict[str, str] = {}, require: list[str] = []):
        self.func = view_func
        self.args = {key.lower(): value for key, value in args.items()}
        self.parameters = inspect.signature(view_func).parameters.items()
        for v in require:
            self.require[v.lower()] = True

# Type variable for the view functions accepted by ``HttpServer.routex``.
# It is bound to Flask's ``RouteCallable``, so the decorator is typed to return
# the same function type it receives.
T_route = t.TypeVar("T_route", bound=ft.RouteCallable)

扩展 Flask，提供一个新的装饰器 @app.routex 来实现参数校验和注入。

class HttpServer(Flask):
    r"""
    Flask application with an extra ``routex`` decorator.

    ``routex`` registers a URL rule like ``route`` does. It also records a
    ``RuleConfg`` (parameter mapping + required parameters) in ``rule_map``,
    keyed by the lower-cased rule string.
    """
    # NOTE(review): not used by any code visible here; kept for compatibility.
    executor = ThreadPoolExecutor()
    rule_map: Dict[str, RuleConfg]

    def __init__(self, *args: t.Any, **kwargs: t.Any):
        # One map per application: a class-level dict would be shared by every
        # HttpServer instance and mix their routes together.
        self.rule_map = {}
        super().__init__(*args, **kwargs)

    def routex(self, rule: str, args: Dict[str, str] | None = None, require: list[str] | None = None, **options: t.Any) -> t.Callable[[T_route], T_route]:
        r"""
        Register `rule` and remember how to inject the view's parameters.

        `rule`: URL rule, as for ``Flask.route``.
        `args`: parameter mapping (view parameter name -> request parameter name).
        `require`: request parameter names that must be present and non-empty.
        `options`: forwarded to ``add_url_rule`` (e.g. ``methods``). An explicit
        ``endpoint`` is honoured; otherwise one is derived from `rule`.
        """
        # Pop here (as Flask.route does) so ``endpoint`` is not passed twice.
        endpoint = options.pop('endpoint', None) or rule.replace('/', '_')

        def decorator(f: T_route) -> T_route:
            # e.g. add_url_rule('/add', server, HttpServer.add, {'text': 'txt'}, ['txt', 'data'], methods=['GET','POST'])
            print('注册路由:', rule, endpoint, f)
            self.add_url_rule(rule, endpoint, f, **options)
            # Record only after Flask accepted the rule, so a failed
            # registration leaves no stale entry for the interceptor.
            self.rule_map[rule.lower()] = RuleConfg(f, args or {}, require or [])
            return f

        return decorator

编写 Controller

# API controller
class APIController:
    r"""
    Registers the API routes on an ``HttpServer`` through ``routex``.

    View parameters are filled from the request by name. Their annotations
    (``int``, ``float``, ``bool``, ...) declare the expected type of each value.
    """
    app: HttpServer

    def __init__(self, app: HttpServer):
        self.app = app

        @app.routex('/add', args={}, require=['q', 'a'], methods=['GET','POST'])
        def add(q: str, a: str, cls: str | None = None, keys: list[str] | None = None, imgs: list[Dict[str, t.Any]] | None = None) -> R:
            r"""
            Add Q&A data (currently a stub that always reports failure).
            """
            return R.error("添加失败")

        @app.routex('/query', args={}, require=['q'], methods=['GET','POST'])
        def query_answer(q: str, topK: int = 5, minScore: float = 0.3, useGpt: bool = False) -> R:
            r"""
            Query a question (currently a stub that always returns 'ok').
            """
            # minScore is fractional (default 0.3), so it is annotated float;
            # with ``int`` a value such as "0.5" could not be converted.
            return R.success('ok')

        @app.routex('/query_qa_data', args={}, require=[], methods=['GET','POST'])
        def user_query_qa_data(page: int = 1, pageSize: int = 50, src: int = 0, cls: str | None = None, searchKey: str | None = None) -> R:
            r"""
            Query the Q&A knowledge base.
            """
            # NOTE(review): query_qa_data is not defined in this snippet;
            # presumably provided elsewhere in the project — confirm.
            return R.success(query_qa_data(page, pageSize, src, cls, searchKey))

        @app.routex('/test', args={}, require=[], methods=['GET','POST'])
        def test(src: int = 0) -> R:
            r"""
            Echo `src` back in a success response.
            """
            return R.success(src)

初始化 HTTP 服务

if __name__ == '__main__':
    # Start the service only when this file is executed directly, so that
    # importing the module has no side effects.
    print("【启动服务】")
    http_server_port = 8080
    init_server(http_server_port)