FastAPI工程化项目案例

2023-08-21 17:19:55

本案例使用Python语言编写,业务场景为简单的用户操作,整个项目是对FastAPI第三方库的工程化实现。由于博主技术水平有限,如果你在本地测试过程中发现存在BUG,纯属正常,感谢您的理解!!!(本案例使用Python語言編寫,業務場景為簡單的用戶操作,整個項目是對FastAPI第三方庫的工程化實現。由於博主技術水平有限,如果你在本地測試過程中發現存在BUG,純屬正常,感謝您的理解! ! !)

1.项目结构
|--app
| |--apiv1
| | |--endpoint
| | | | __init__.py
| | | | user.py
| | | __init__.py
| | | router.py
| |--common
| | | __init__.py
| | | constant.py
| | | encrypto.py
| | | utils.py
| |--core
| | |--init
| | | | __init__.py
| | | | cors.py
| | | | database.py
| | | | exception.py
| | | | grequest.py
| | | | logger.py
| | | | middleware.py
| | | | router.py
| | | __init__.py
| | | codeenum.py
| | | httpresp.py
| | | localproxy.py
| | | serialize.py
| |--model
| | | __init__.py
| | | base.py
| | | user.py
| |--schema
| | | __init__.py
| | | user.py
| |--service
| | | __init__.py
| | | user.py
| | __init__.py
| .env
| __init__.py
| config.py
| main.py
2.模块实现
2.1公共模块
  • 常量集定义
# Pagination defaults and the generic "failed" sentinel value.
DEFAULT_PAGE, DEFAULT_PER_PAGE, DEFAULT_FAIL = 1, 10, -1

# Cache lifetimes in seconds, each derived from the next smaller unit.
CACHE_FIVE_SECONDS = 5
CACHE_MINUTE = 60
CACHE_THREE_MINUTE = 3 * CACHE_MINUTE
CACHE_FIVE_MINUTE = 5 * CACHE_MINUTE
CACHE_TEN_MINUTE = 10 * CACHE_MINUTE
CACHE_HALF_HOUR = 30 * CACHE_MINUTE
CACHE_HOUR = 60 * CACHE_MINUTE
CACHE_THREE_HOUR = 3 * CACHE_HOUR
CACHE_TWELVE_HOUR = 12 * CACHE_HOUR
CACHE_DAY = 24 * CACHE_HOUR
CACHE_WEEK = 7 * CACHE_DAY
CACHE_MONTH = 30 * CACHE_DAY

# Redis key templates; fill the placeholder with str.format().
TEST_USER_INFO = 'zer0py2c:user_token:{0}'
TEST_EXECUTE_SET = 'zer0py2c:test_execute_set:case:{}'
TEST_EXECUTE_STATS = 'zer0py2c:test_execute_set:stats:{}'
TEST_EXECUTE_TASK = 'zer0py2c:test_execute_set:task:{}'
TEST_EXECUTE_PARAMETER = 'zer0py2c:test_execute_set:extract_parameter:{}'
DATA_STRUCTURE_CASE_UPDATE = 'zer0py2c:data_structure:user:{}'
TEST_USER_LOGIN_TIME = 'zer0py2c:user_login_time:{}'

# Performance-test run state: the Redis key and its status values.
# (The "PREFORMANCE" spelling is part of the public names and is kept as-is.)
PREFORMANCE_RUN_STATUS = 'performance_test:status'
PREFORMANCE_FREE, PREFORMANCE_INIT, PREFORMANCE_BUSY, PREFORMANCE_ABORT = 0, 10, 20, 30

PREFORMANCE_CODE = 'code'
PREFORMANCE_SIGN_CODE = 'sign_code'

# Upper bounds for thread count, run count and debug runs.
THREAD_MAXMUM = 100
RUN_NUMBER_MAXMUM = 1_000_000
DEBUG_MAXMUM = 100
  • 加密与解密
import base64

from Cryptodome import Random
from Cryptodome.Cipher import PKCS1_v1_5
from Cryptodome.PublicKey import RSA

import os

# RSA key pair used by encrypto_rsa_password / decrypto_rsa_password.
# SECURITY: the built-in pair below is published together with this source code,
# so anyone can decrypt values protected by it. In a real deployment, supply your
# own pair through the RSA_PUBLIC_KEY and RSA_PRIVATE_KEY environment variables
# (set both, so they match). The built-in pair remains the default so that
# values already encrypted with it can still be decrypted.
_DEFAULT_PUBLIC_KEY = """-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC2YZVJzRrn1kyJHZS+7O5/oteO
YOkbiNk3ndRLQscgdDf3k+RaRomzvHro5w2h6T9A5rd45vM0kyKcBezE/Za1pOKq
meovah1zxxoofQJ8k91ybVFXYJx99k9ravCMr+wKuCpuuwPe8he10iBZ465vVZ6g
5Nbg4gM2PcV7OMVLaQIDAQAB
-----END PUBLIC KEY-----"""

_DEFAULT_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQC2YZVJzRrn1kyJHZS+7O5/oteOYOkbiNk3ndRLQscgdDf3k+Ra
RomzvHro5w2h6T9A5rd45vM0kyKcBezE/Za1pOKqmeovah1zxxoofQJ8k91ybVFX
YJx99k9ravCMr+wKuCpuuwPe8he10iBZ465vVZ6g5Nbg4gM2PcV7OMVLaQIDAQAB
AoGABgNhutMngjyVcta4omgOhS3jLcjJ8sbrA4Y220w5DhvALc+7XBMejpWmAfMT
8YekAWGsq7CwqjDON7Gge3kRdz7PDwjaPBwkOebD1aYNWDM0TfQiINVxCkZpPoKg
KTpIELQUoD6KMWw8NUwcasqHcz1HCC6DnRYpG3XJXYhdJDECQQDCvQ/tkHOi92He
RioTHSiJd/5TvgPgBH7dqsldT6mwS67EbrWFEiSSRbzref6wv+r8sXb9d436Ltno
4lngQWZZAkEA78FX7EzS/TAV/PDfWh/ncozY9tFqfPNk4w96LVb5wy8oc9M419K9
yLWeSfiBcXK2l+S2XYk49OhuznklZWiLkQJAL1c+1AXV1rxE8oAkIlloTWL6VOlQ
j9kH7mNiaGjBW7ZKWj5/qkXq1hRWBPi3TciaG6wYvS2fOj7BgrfkGXxMoQJBAIbT
zNUHEvPtMcBP2Nr+7BJgILcUZ3UjDw4dqxCKQ+S+xVn1Y5cDXVTcxcpFZM3eu85J
gUCypYQcngug1yXjF/ECQQCBZZAZ+GhpzqwerwqyfNHvrahSrfp14l6STktaCjKy
IR4n5TomCkHRaeXPgn1YVIhz5/LaVZJuKK3eiN2Wbdwy
-----END RSA PRIVATE KEY-----"""

PUBLIC_KEY = os.environ.get("RSA_PUBLIC_KEY", _DEFAULT_PUBLIC_KEY)
PRIVATE_KEY = os.environ.get("RSA_PRIVATE_KEY", _DEFAULT_PRIVATE_KEY)


def encrypto_rsa_password(password):
    """Encrypt *password* with the module's RSA public key (PKCS#1 v1.5).

    Returns the ciphertext as a base64 text ``str``, so it can be stored in a
    text column and passed back to ``decrypto_rsa_password``.

    On any failure the input is returned unchanged. This best-effort fallback
    is kept for backward compatibility.
    NOTE(review): it means a failed encryption silently yields the plaintext.
    """
    try:
        public_key = RSA.import_key(PUBLIC_KEY)
        cipher = PKCS1_v1_5.new(public_key)
        text = cipher.encrypt(password.encode("utf8"))
        # b64encode returns bytes. Decode it so the success path yields str,
        # like the fallback path does.
        return base64.b64encode(text).decode("ascii")
    except Exception:
        return password


def decrypto_rsa_password(password):
    """Decrypt a base64 PKCS#1 v1.5 ciphertext made with the module's RSA key pair.

    Returns the plaintext as ``str``. If *password* cannot be base64-decoded or
    decrypted, it is returned unchanged. This best-effort fallback is kept for
    backward compatibility; presumably it exists for values stored unencrypted.
    """
    try:
        private_key = RSA.import_key(PRIVATE_KEY)
        cipher = PKCS1_v1_5.new(private_key)
        # PKCS#1 v1.5 decrypt reports a padding failure by returning the
        # sentinel instead of raising. A random sentinel cannot be mistaken for
        # a real plaintext. An empty b"" sentinel would turn every
        # undecryptable value into "", which an empty password then matches.
        sentinel = Random.get_random_bytes(16)
        text = cipher.decrypt(base64.b64decode(password), sentinel)
        if text == sentinel:
            return password
        return text.decode()
    except Exception:
        return password
  • 辅助工具
import re
import os
import hashlib
import zipfile
import uuid
import typing

from web.backend.app.core import g
from web.backend.app.core.init.database import init_redis_pool
from web.backend.app.core.init.exception import AccessTokenFail
from web.backend.app.common.constant import TEST_USER_INFO


def get_str_uuid():
    """Return a new random UUID4 as 32 lowercase hex characters, without dashes."""
    return uuid.uuid4().hex


async def current_user(token: str = None) -> typing.Union[typing.Dict[typing.Text, typing.Any], None]:
    """Return the cached session user for *token*.

    If *token* is falsy, the request-scoped ``g.token`` is used instead. The
    Redis client on ``g`` is created lazily when it is not set yet.

    Raises:
        AccessTokenFail: no cached user exists for the token.
    """
    if not g.redis:
        g.redis = await init_redis_pool()
    lookup_token = token if token else g.token
    cached = await g.redis.get(TEST_USER_INFO.format(lookup_token))
    if cached:
        return cached
    raise AccessTokenFail()


class MZipFile(object):
    """Small read-only helper around :class:`zipfile.ZipFile`.

    The archive stays open for the life of the object. Call :meth:`close`, or
    use the instance as a context manager, to release the file handle.
    """

    def __init__(self, zip_path):
        # zip_path may be a filesystem path or a binary file-like object
        # (anything zipfile.ZipFile accepts).
        self.zip = zipfile.ZipFile(zip_path, 'r')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Close the underlying archive."""
        self.zip.close()

    def get_filecount(self):
        """Return the number of entries in the archive."""
        return len(self.zip.namelist())

    def get_one_file(self):
        """Yield the decoded lines of each archive entry, one entry at a time."""
        for name in self.zip.namelist():
            yield self.read_lines(name)

    def read_lines(self, name):
        """Return entry *name* as a list of UTF-8-decoded lines (line endings kept)."""
        # Close the member stream explicitly instead of leaking it.
        with self.zip.open(name) as member:
            return [line.decode() for line in member.readlines()]

    def get_filenames(self):
        """Return the archive's entry names in archive order."""
        return self.zip.namelist()

    def extract_to(self, path):
        """Extract every entry into *path* and return *path*."""
        self.zip.extractall(path)
        return path

    @staticmethod
    def cal_hash_by_file(path, algorithm, chunk_size=1024 * 1024):
        """Feed the file at *path* into the hashlib object *algorithm*; return its hex digest.

        The file is read in ``chunk_size`` pieces, so large files are never
        loaded into memory at once.
        """
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                algorithm.update(chunk)
        return algorithm.hexdigest()


class MD5Utils:
    """Helpers for archives whose entries are named ``<archive prefix>/<hash>``."""

    @staticmethod
    def _hashes_in_zip(zip_name):
        """Return the hash part of every matching entry in *zip_name*, in archive order.

        The prefix is *zip_name* up to its first ".". An entry (after stripping
        whitespace) matches when it starts with ``<prefix>/`` followed by
        ``[0-9a-z]+``; that run is the hash. The prefix is regex-escaped, so
        names containing "+", "(" and similar characters are matched literally.
        The archive is closed before returning.
        """
        prefix = zip_name.split(".")[0]
        pattern = re.compile(re.escape(prefix) + "/([0-9a-z]+)\n?")
        with zipfile.ZipFile(zip_name, "r") as archive:
            names = archive.namelist()
        hashes = []
        for name in names:
            match = pattern.match(name.strip())
            if match:
                hashes.append(match.group(1))
        return hashes

    @staticmethod
    def read_file_by_zip(zip_name, file_name, mode="w"):
        """Write the hashes found in *zip_name* to *file_name*, one per line (UTF-8)."""
        hashes = MD5Utils._hashes_in_zip(zip_name)
        with open(file_name, mode, encoding="utf-8") as fp:
            fp.writelines(h + "\n" for h in hashes)

    @staticmethod
    def list_file_by_zip(zip_name, file_name=None, mode="w"):
        """Return the hashes found in *zip_name* as a list.

        *file_name* and *mode* are unused. They are kept only so existing
        callers that pass them keep working.
        """
        return MD5Utils._hashes_in_zip(zip_name)

    @staticmethod
    def cal_file_hash(file_name, chunk_size=1024 * 1024):
        """Return the ``(md5, sha1, sha256)`` hex digests of *file_name*.

        All three digests are computed in a single chunked pass over the file.
        """
        digests = (hashlib.md5(), hashlib.sha1(), hashlib.sha256())
        with open(file_name, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                for digest in digests:
                    digest.update(chunk)
        md5, sha1, sha256 = (digest.hexdigest() for digest in digests)
        return md5, sha1, sha256
2.2核心模块
  • 业务状态码
from enum import Enum


class CodeEnum(Enum):
    """Business status codes: each member's value is a ``(code, message)`` pair."""

    @property
    def code(self):
        """Numeric business code (first element of the value tuple)."""
        return self.value[0]

    @property
    def msg(self):
        """Human-readable message (second element of the value tuple)."""
        return self.value[1]

    PARTNER_CODE_OK = (0, "SUCCESS")
    PARTNER_CODE_FAIL = (-1, "FAILED")

    # 100xx: account / password
    WRONG_USER_NAME_OR_PASSWORD = (10001, "账号或者密码错误!")
    PARTNER_CODE_EMPLOYEE_FAIL = (10002, "账号错误!")
    WRONG_USER_NAME_OR_PASSWORD_LOCK = (10003, "密码输入错误超过次数,请5分钟后再登录!")
    USERNAME_OR_EMAIL_IS_REGISTER = (10004, "用户名已被注册")
    USER_ID_IS_NULL = (10005, "用户id不能为空")
    PASSWORD_TWICE_IS_NOT_AGREEMENT = (10006, "两次输入的密码不一致")
    NEW_PWD_NO_OLD_PWD_EQUAL = (10007, "新密码不能与旧密码相同")
    OLD_PASSWORD_ERROR = (10008, "旧密码错误")

    # 11xxx: session / token
    PARTNER_CODE_TOKEN_EXPIRED_FAIL = (11000, "用户信息已过期")

    # 12xxx: request parameters
    PARTNER_CODE_PARAMS_FAIL = (12000, "必填参数不能为空")

    # 13xxx: projects
    PROJECT_HAS_MODULE_ASSOCIATION = (13000, "项目有模块或用例关联,不能删除")
    PROJECT_NAME_EXIST = (13001, "项目名已存在")

    # 14xxx: modules (message typo fixed: stray leading space, "对于" -> "对应")
    MODULE_HAS_CASE_ASSOCIATION = (14000, "模块有用例关联, 请删除对应模块下的用例")
    MODULE_NAME_EXIST = (14001, "模块名已存在")

    # 15xxx: test cases / suites
    CASE_NAME_EXIST = (15000, "用例名已存在,请重新命名")
    SUITE_NAME_EXIST = (15001, "套件名已存在,请重新命名")
    CASE_NOT_EXIST = (15002, "用例不存在")
    CASE_UPLOAD_FROM_POSTMAN = (15003, "导入失败")

    # 16xxx: menus
    MENU_HAS_MODULE_ASSOCIATION = (16000, "当前菜单下管理的子菜单,不能删除!")
    MENU_NAME_EXIST = (16001, "菜单名称已存在")

    # 17xxx: lookup dictionaries
    LOOKUP_CODE_NOT_EMPTY = (17000, "字典code不能为空!")
    LOOKUP_NOT_EXIST = (17001, "字典不存在!")
    LOOKUP_CODE_EXIST = (17002, "字典code已存在!")

    # 18xxx: scheduled tasks (message typo fixed: "以存在" -> "已存在")
    TASK_NAME_EXIST = (18000, "定时任务名称已存在")
  • 自定义HTTP响应对象
import typing

import orjson
from starlette import status
from starlette.responses import Response, JSONResponse

from web.backend.app.core import g
from web.backend.app.core.codeenum import CodeEnum
from web.backend.app.core.serialize import default_serialize


class ORJSONResponse(JSONResponse):
    """JSONResponse whose body is encoded with orjson.

    Non-string dict keys and numpy values are accepted. Datetimes are passed
    through (not encoded natively), so they reach ``default_serialize``, as
    does any other type orjson cannot encode itself.
    """

    media_type = "application/json"

    def render(self, content: typing.Any) -> bytes:
        flags = orjson.OPT_NON_STR_KEYS
        flags |= orjson.OPT_SERIALIZE_NUMPY
        flags |= orjson.OPT_PASSTHROUGH_DATETIME
        return orjson.dumps(content, default=default_serialize, option=flags)


def partner_success(
        data=None, headers=None,
        code=CodeEnum.PARTNER_CODE_OK.code,
        msg=CodeEnum.PARTNER_CODE_OK.msg,
        http_code=status.HTTP_200_OK,
):
    """Build the standard business envelope response.

    The body is ``{code, msg, data, success, trace_id}``:
    - ``success`` is true only when *code* equals PARTNER_CODE_OK's code.
    - ``data`` defaults to an empty dict.
    - ``trace_id`` is read from ``g``.
    """
    body = {
        "code": code,
        "msg": msg,
        "data": {} if data is None else data,
        "success": bool(code == CodeEnum.PARTNER_CODE_OK.code),
        "trace_id": g.trace_id,
    }
    return ORJSONResponse(status_code=http_code, content=body, headers=headers or {})


def resp_200(*, data: typing.Any = "", msg: str = "HTTP_200_OK") -> dict:
    """Return a plain ``{code: 200, data, msg}`` dict (not a Response object)."""
    return {"code": 200, "data": data, "msg": msg}


def _status_response(
        http_code: int,
        data: typing.Any,
        msg: typing.Any,
        headers: typing.Optional[typing.Dict[str, str]] = None,
) -> Response:
    """Build an ORJSONResponse whose body ``code`` mirrors the HTTP status code."""
    return ORJSONResponse(
        status_code=http_code,
        content={"code": http_code, "data": data, "msg": msg},
        headers=headers,
    )


def resp_400(*, data: typing.Optional[str] = None, msg: str = "HTTP_400_BAD_REQUEST") -> Response:
    """HTTP 400 response with body ``{code: 400, data, msg}``."""
    return _status_response(status.HTTP_400_BAD_REQUEST, data, msg)


def resp_401(*, data: typing.Optional[str] = None, msg: str = "HTTP_401_UNAUTHORIZED") -> Response:
    """HTTP 401 response with body ``{code: 401, data, msg}``."""
    return _status_response(status.HTTP_401_UNAUTHORIZED, data, msg)


def resp_403(*, data: typing.Optional[str] = None, msg: str = "HTTP_403_FORBIDDEN") -> Response:
    """HTTP 403 response with body ``{code: 403, data, msg}``."""
    return _status_response(status.HTTP_403_FORBIDDEN, data, msg)


def resp_404(*, data: typing.Optional[str] = None, msg: str = "HTTP_404_NOT_FOUND") -> Response:
    """HTTP 404 response with body ``{code: 404, data, msg}``."""
    return _status_response(status.HTTP_404_NOT_FOUND, data, msg)


def resp_408(*, data: typing.Optional[str] = None, msg: str = "HTTP_408_REQUEST_TIMEOUT") -> Response:
    """HTTP 408 response with body ``{code: 408, data, msg}``."""
    return _status_response(status.HTTP_408_REQUEST_TIMEOUT, data, msg)


def resp_429(*, data: typing.Optional[str] = None, msg: str = "HTTP_429_TOO_MANY_REQUESTS") -> Response:
    """HTTP 429 response with body ``{code: 429, data, msg}``."""
    return _status_response(status.HTTP_429_TOO_MANY_REQUESTS, data, msg)


def resp_500(*, data: typing.Optional[str] = None,
             msg: typing.Union[list, dict, str] = "HTTP_500_INTERNAL_SERVER_ERROR") -> Response:
    """HTTP 500 response with body ``{code: 500, data, msg}`` and a wildcard CORS header."""
    return _status_response(status.HTTP_500_INTERNAL_SERVER_ERROR, data, msg,
                            headers={"Access-Control-Allow-Origin": "*"})


def resp_501(*, data: typing.Optional[str] = None, msg: str = "HTTP_501_NOT_IMPLEMENTED") -> Response:
    """HTTP 501 response with body ``{code: 501, data, msg}``."""
    return _status_response(status.HTTP_501_NOT_IMPLEMENTED, data, msg)


def resp_502(*, data: typing.Optional[str] = None, msg: str = "HTTP_502_BAD_GATEWAY") -> Response:
    """HTTP 502 response with body ``{code: 502, data, msg}``."""
    return _status_response(status.HTTP_502_BAD_GATEWAY, data, msg)
  • 全局代理对象
import typing
from contextvars import ContextVar


class LocalProxy(object):
    """Context-local attribute namespace backed by a single ContextVar.

    Attribute values live in a dict held by the ContextVar. Each write copies
    that dict and stores the copy, so a dict that another context already
    references is never mutated in place. Reading an attribute that was never
    set returns ``None`` instead of raising ``AttributeError``.
    """

    __slots__ = ("_storage",)

    def __init__(self) -> None:
        # Bypass our own __setattr__, which would otherwise try to store
        # "_storage" inside the ContextVar dict.
        object.__setattr__(self, "_storage", ContextVar("local_storage"))

    def __iter__(self) -> typing.Iterator[typing.Tuple[str, typing.Any]]:
        """Iterate over the ``(name, value)`` pairs set in the current context."""
        return iter(self._storage.get({}).items())

    def __release_local__(self) -> None:
        """Drop every value in the current context."""
        self._storage.set({})

    def __getattr__(self, name: str) -> typing.Any:
        # Only called when normal lookup fails, i.e. for stored values.
        return self._storage.get({}).get(name)

    def __setattr__(self, name: str, value: typing.Any) -> None:
        updated = dict(self._storage.get({}))
        updated[name] = value
        self._storage.set(updated)

    def __delattr__(self, name: str) -> None:
        current = self._storage.get({})
        if name not in current:
            # Deleting an unset attribute is deliberately a no-op.
            return
        updated = dict(current)
        del updated[name]
        self._storage.set(updated)


# Process-wide proxy; each context sees only its own values.
g = LocalProxy()
  • 序列化工具
import typing
from datetime import datetime

from fastapi.encoders import jsonable_encoder
from sqlalchemy import Row
from sqlalchemy.orm import DeclarativeMeta


def unwrap_scalars(items: typing.Union[typing.Sequence[Row], Row]) \
        -> typing.Union[typing.List[typing.Dict[typing.Text, typing.Any]], typing.Dict[str, typing.Any]]:
    """Serialize a single row/model, or every element of an iterable of them.

    A ``Row`` is treated as a single item even though it is iterable.
    """
    is_collection = isinstance(items, typing.Iterable) and not isinstance(items, Row)
    if not is_collection:
        return default_serialize(items)
    return list(map(default_serialize, items))


def default_serialize(obj):
    """Convert *obj* into a JSON-compatible value.

    - ``datetime`` becomes a ``"YYYY-MM-DD HH:MM:SS"`` string.
    - A SQLAlchemy ``Row`` becomes a dict of field name to serialized value.
    - A declarative model instance becomes a dict of table column name to
      serialized value.
    - Anything else goes through ``fastapi.encoders.jsonable_encoder``.

    A ``TypeError`` raised while handling the first three cases yields
    ``repr(obj)`` instead.
    """
    try:
        if isinstance(obj, datetime):
            return obj.strftime("%Y-%m-%d %H:%M:%S")
        if isinstance(obj, Row):
            # Pair the public field names with the row's values by iterating
            # the Row itself, instead of reading the internal Row._data.
            data = dict(zip(obj._fields, obj))
            return {key: default_serialize(value) for key, value in data.items()}
        if isinstance(obj.__class__, DeclarativeMeta):
            return {c.name: default_serialize(getattr(obj, c.name)) for c in obj.__table__.columns}
    except TypeError:
        return repr(obj)
    return jsonable_encoder(obj)
2.3初始化模块
  • 允许跨域
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from web.backend.config import config


def init_cors(app: FastAPI):
    """Register Starlette's CORSMiddleware on *app*.

    Origins come from ``config.CORS_ORIGINS`` (each converted to ``str``).
    Credentials and all request headers are allowed. Only the GET, POST, PUT
    and DELETE methods are listed.
    """
    origins = list(map(str, config.CORS_ORIGINS))
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE"],
        allow_headers=["*"],
    )
  • 数据库连接
import json
import typing
import functools

from aioredis import Redis
from asyncio import current_task

from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import create_async_engine, \
AsyncSession, async_scoped_session, async_sessionmaker
from web.backend.config import config


class AdvancedRedis(Redis):
    """Redis client that JSON-encodes values on ``set`` and decodes them on ``get``."""

    async def get(self, name: str) -> typing.Any:
        """Return the JSON-decoded value stored at *name*, or None if it is missing or empty."""
        data = await super().get(name)
        return json.loads(data) if data else None

    async def set(
            self,
            name: str,
            value: typing.Any,
            ex=None,
            px=None,
            nx: bool = False,
            xx: bool = False,
            keepttl: bool = False,
    ) -> typing.Any:
        """JSON-encode *value* and store it at *name*.

        Every expiry/condition option is forwarded to Redis, so ``px``, ``nx``,
        ``xx`` and ``keepttl`` take effect rather than being silently dropped.
        """
        return await super().set(
            name, json.dumps(value), ex=ex, px=px, nx=nx, xx=xx, keepttl=keepttl
        )


async def init_redis_pool() -> AdvancedRedis:
    """Create an AdvancedRedis client for ``config.REDIS_URI``.

    Replies are decoded to ``str`` using ``config.GLOBAL_ENCODING``.
    """
    return await AdvancedRedis.from_url(
        url=config.REDIS_URI,
        encoding=config.GLOBAL_ENCODING,
        decode_responses=True,
    )


# Async SQLAlchemy engine. Pooled connections are pinged before use and
# recycled after two hours.
db_engine = create_async_engine(
    url=config.DATABASE_URI,
    echo=config.DATABASE_ECHO,
    pool_pre_ping=True,
    pool_recycle=2 * 60 * 60,
)

# Session factory. expire_on_commit=False keeps loaded attributes readable after commit.
async_session_factory = async_sessionmaker(
    bind=db_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)

# Scoped registry keyed by asyncio.current_task: one session per task.
async_session = async_scoped_session(async_session_factory, scopefunc=current_task)


def database_session(func: typing.Callable):
    """Decorator that supplies an AsyncSession as *func*'s ``session`` argument.

    If the caller already passes ``session`` (by keyword, or positionally when
    ``session`` is one of *func*'s positional parameters), *func* is awaited
    as-is and the caller owns the transaction.

    Otherwise a session is taken from the module's ``async_session`` registry:
    - it is committed only if *func* succeeds;
    - it is rolled back if *func* or the commit raises;
    - it is always removed from the registry afterwards.
    """
    arg_session = "session"
    code = func.__code__
    # Only the positional parameter names. co_varnames also lists local
    # variables, which must not be mistaken for a ``session`` parameter.
    positional_params = code.co_varnames[:code.co_argcount]

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        session_in_args = (
            arg_session in positional_params
            and positional_params.index(arg_session) < len(args)
        )
        if arg_session in kwargs or session_in_args:
            return await func(*args, **kwargs)
        async with async_session() as session:
            try:
                result = await func(*args, session=session, **kwargs)
                await session.commit()
                return result
            except Exception:
                # Roll back on any failure, not just IntegrityError, so partial
                # work is never committed after an error.
                await session.rollback()
                raise
            finally:
                await async_session.remove()

    return wrapper
  • 自定义异常类
import typing
import traceback

from loguru import logger
from fastapi import FastAPI, Request

from web.backend.app.core.codeenum import CodeEnum
from web.backend.app.core.httpresp import resp_400, resp_404, partner_success


class APIBaseException(Exception):
    """Base API error carrying a business ``code`` and ``msg``.

    Pass a CodeEnum member to use its code and message. Pass a plain string to
    use PARTNER_CODE_FAIL's code with the string as the message.
    """

    def __init__(self, err_or_code: typing.Union[CodeEnum, str]):
        if not isinstance(err_or_code, CodeEnum):
            self.code, self.msg = CodeEnum.PARTNER_CODE_FAIL.code, err_or_code
            return
        self.code, self.msg = err_or_code.code, err_or_code.msg

    def __str__(self):
        return "{}:{}".format(self.code, self.msg)

    __repr__ = __str__


class IpError(APIBaseException):
    """Rejected client IP; message "ip error"."""

    def __init__(self):
        super().__init__("ip error")


class UserNotExist(APIBaseException):
    """Requested user does not exist; message "user not exist"."""

    def __init__(self):
        super().__init__("user not exist")


class AccessTokenFail(APIBaseException):
    """Missing or expired login token (PARTNER_CODE_TOKEN_EXPIRED_FAIL)."""

    def __init__(self):
        super().__init__(CodeEnum.PARTNER_CODE_TOKEN_EXPIRED_FAIL)


class ParameterError(APIBaseException):
    """Invalid request parameters; takes a CodeEnum member or a message string."""

    def __init__(self, err_code: typing.Union[CodeEnum, str]):
        super().__init__(err_code)


def init_exception(app: FastAPI):
    """Register the application's exception handlers on *app*.

    - IpError: HTTP 400 via resp_400.
    - UserNotExist: HTTP 404 via resp_404.
    - AccessTokenFail: business envelope with its own code/msg. This is the
      same response the login middleware returns for a bad token.
    - ParameterError: business envelope with its own code/msg.
    - Any other Exception: business envelope with PARTNER_CODE_FAIL and
      ``str(exc)``, plus a wildcard CORS header.
    """

    def _format_exc(exc: BaseException) -> str:
        # Format from the exception object itself instead of sys.exc_info(),
        # which is not guaranteed to be populated when a handler runs.
        return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))

    @app.exception_handler(IpError)
    async def ip_error_handler(request: Request, exc: IpError):
        logger.warning(f"{exc.msg}:{exc.code}\nURL:{request.method}-{request.url}\nHeaders:{request.headers}")
        return resp_400(msg=exc.msg)

    @app.exception_handler(UserNotExist)
    async def user_not_exist_handler(request: Request, exc: UserNotExist):
        logger.warning(f"{exc.msg}:{exc.code}\nURL:{request.method}-{request.url}\nHeaders:{request.headers}")
        return resp_404(msg=exc.msg)

    @app.exception_handler(AccessTokenFail)
    async def access_token_fail_handler(request: Request, exc: AccessTokenFail):
        # Without this handler the generic one reported code -1 with the
        # "code:msg" string as the message.
        logger.warning(f"{exc.msg}:{exc.code}\nURL:{request.method}-{request.url}\nHeaders:{request.headers}")
        return partner_success(code=exc.code, msg=exc.msg)

    @app.exception_handler(ParameterError)
    async def parameter_error_handler(request: Request, exc: ParameterError):
        logger.error(f"Params Error\n{request.method}URL:{request.url}\n"
                     f"Headers:{request.headers}\n{_format_exc(exc)}")
        return partner_success(code=exc.code, msg=exc.msg)

    @app.exception_handler(Exception)
    async def all_exception_handler(request: Request, exc: Exception):
        logger.error(f"Global Error\n{request.method} URL:{request.url}\n"
                     f"Headers:{request.headers}\n{_format_exc(exc)}")
        return partner_success(code=CodeEnum.PARTNER_CODE_FAIL.code, msg=str(exc),
                               headers={"Access-Control-Allow-Origin": "*"})
  • 全局请求对象
from fastapi import Request
from web.backend.app.core import g


async def init_global_request(request: Request):
    """Store the current Request on the context-local ``g`` as ``g.request``.

    Presumably registered elsewhere as a FastAPI dependency; verify at the call site.
    """
    setattr(g, "request", request)
  • 日志
import os
import sys
import uuid
import logging
from pathlib import Path

from loguru import logger
from web.backend.config import config
from web.backend.app.core import g


def logger_file(log_dir: str = None, log_name: str = None, keep: int = 3) -> str:
    """Ensure the log directory exists, prune old files, and return the log file path.

    Args:
        log_dir: Directory for log files; defaults to ``config.LOGGER_DIR``.
            Relative paths are resolved against the current working directory;
            missing parent directories are created.
        log_name: File name of the log; defaults to ``config.LOGGER_NAME``.
        keep: Maximum number of existing regular files to leave in the
            directory. The oldest files (by modification time) beyond this
            count are deleted.

    Returns:
        The log file path as a string.
    """
    if log_dir is None:
        log_dir = config.LOGGER_DIR
    if log_name is None:
        log_name = config.LOGGER_NAME

    log_path = Path(log_dir).absolute()
    log_path.mkdir(parents=True, exist_ok=True)

    files = sorted(
        (p for p in log_path.iterdir() if p.is_file()),
        key=lambda p: p.stat().st_mtime,
    )
    for stale in files[:max(len(files) - keep, 0)]:
        stale.unlink()
    return os.path.join(log_path, log_name)


def correlation_id_filter(record):
    """Loguru filter that stamps each record with the current trace id.

    If ``g.trace_id`` is not set yet, a new dash-less UUID4 is stored there.
    The id is then copied into ``record["trace_id"]`` for the ``{trace_id}``
    format field. The record itself (always truthy) is returned, so every
    message passes.
    """
    if not g.trace_id:
        g.trace_id = uuid.uuid4().hex
    record["trace_id"] = g.trace_id
    return record


# Console log line layout; {trace_id} is supplied by correlation_id_filter.
fmt = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green>| {thread} "
    "| <level>{level: <8}</level> | <yellow> {trace_id} </yellow> "
    "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> "
    "| <level>{message}</level>"
)

# Replace loguru's default sink with a single colorized stdout sink.
logger.remove()
logger.add(
    sys.stdout,
    format=fmt,
    filter=correlation_id_filter,
    level=config.LOGGER_LEVEL,
    colorize=True,
)


class InterceptHandler(logging.Handler):
    """Standard-library logging handler that forwards records to loguru."""

    def emit(self, record):
        import inspect

        # Use the loguru level with the same name when it exists; otherwise
        # fall back to the numeric stdlib level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk out of the logging module's own frames so loguru attributes the
        # message to the code that issued the log call. A hard-coded depth
        # breaks whenever the call-stack shape differs.
        frame, depth = inspect.currentframe(), 0
        while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())


def init_logger():
    """Route every registered stdlib logger through loguru.

    - Loggers whose effective level is below ``config.LOGGER_LEVEL`` are
      raised to that level.
    - Top-level loggers (names without a ".") have their handlers replaced by
      a single InterceptHandler.
    """
    threshold = config.LOGGER_LEVEL.upper()
    for name in list(logging.root.manager.loggerDict):
        std_logger = logging.getLogger(name)
        if std_logger.getEffectiveLevel() < logging.getLevelName(threshold):
            std_logger.setLevel(threshold)
        if "." in name:
            continue
        std_logger.handlers = []
        std_logger.addHandler(InterceptHandler())
  • 中间件
from loguru import logger
from fastapi import FastAPI, Request

from web.backend.config import config
from web.backend.app.core import g
from web.backend.app.core.httpresp import partner_success
from web.backend.app.core.init.exception import AccessTokenFail
from web.backend.app.common.utils import get_str_uuid
from web.backend.app.common.constant import TEST_USER_INFO, CACHE_DAY


async def login_verification(request: Request):
    """Reject unauthenticated calls to protected routes.

    A route is protected when its path starts with "/api/v1", does not start
    with "/api/v1/file", and is not listed in ``config.WHITE_ROUTER``. For
    those routes the ``token`` header must map to a cached user in Redis. On
    success, that cache entry's lifetime is reset to CACHE_DAY.

    Raises:
        AccessTokenFail: the token is missing or unknown.
    """
    token = request.headers.get("token", None)
    path: str = request.scope.get('path', "")
    is_protected = (
        path.startswith("/api/v1")
        and not path.startswith("/api/v1/file")
        and path not in config.WHITE_ROUTER
    )
    if not is_protected:
        return
    if not token:
        raise AccessTokenFail()
    cache_key = TEST_USER_INFO.format(token)
    user_info = await g.redis.get(cache_key)
    if not user_info:
        raise AccessTokenFail()
    # Sliding expiry: each authenticated request refreshes the token TTL.
    await g.redis.set(cache_key, user_info, CACHE_DAY)


def init_middleware(app: FastAPI):
    """Install the HTTP middleware that sets per-request globals and checks login.

    For every request the middleware:
    - assigns a fresh trace id;
    - exposes ``app.state.redis`` and the ``token`` header on ``g``;
    - logs the access;
    - runs login_verification.

    A token failure short-circuits with the business-envelope response. Every
    response, including that one, carries the trace id in ``X-request-id``.
    """

    @app.middleware("http")
    async def intercept(request: Request, call_next):
        g.trace_id = get_str_uuid()
        token = request.headers.get("token", None)
        g.redis, g.token = app.state.redis, token
        # Prefer the proxy-supplied address. request.client may be None, so it
        # is only dereferenced when present (and only when the header is absent).
        remote_addr = request.headers.get("X-Real-IP") or (
            request.client.host if request.client else None
        )
        logger.info(f"access_record->IP:{remote_addr},method:{request.method},url:{request.url}")
        try:
            await login_verification(request)
        except AccessTokenFail as err:
            response = partner_success(code=err.code, msg=err.msg)
        else:
            response = await call_next(request)
        response.headers["X-request-id"] = g.trace_id
        return response
  • 路由
from fastapi import FastAPI

from web.backend.config import config
from web.backend.app.apiv1.router import api_router


def init_router(app: FastAPI):
    """Mount the v1 API router on *app* under ``config.API_PREFIX``."""
    prefix = config.API_PREFIX
    app.include_router(api_router, prefix=prefix)
2.4Model类模块
  • BaseModel
import typing

from sqlalchemy import Column, DateTime, Integer, func, select, \
update, delete, insert, Select, Executable, Result, ClauseList, String, Boolean
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.declarative import as_declarative
from web.backend.app.core import g
from web.backend.app.core.init.database import database_session
from web.backend.app.core.init.exception import AccessTokenFail
from web.backend.app.core.serialize import unwrap_scalars
from web.backend.app.common.utils import current_user


@as_declarative()
class BaseModel(object):
    """Declarative base with shared audit columns and async CRUD helpers.

    Every helper runs its statement through :meth:`execute`. That method is
    wrapped by ``database_session``, so it obtains a session when none is passed.
    """
    __tablename__: str
    __table_args__ = {"mysql_charset": "utf8"}
    __mapper_args__ = {"eager_defaults": True}

    id = Column(Integer(), nullable=False, primary_key=True, autoincrement=True)
    created_on = Column(DateTime(), default=func.now(), comment="createTime")
    updated_on = Column(DateTime(), default=func.now(), onupdate=func.now(), nullable=False, comment="updateTime")
    created_by = Column(Integer, nullable=True, comment="creator_id")
    updated_by = Column(Integer, nullable=True, comment="updater_id")
    trace_id = Column(String(255), nullable=False, comment="trace_id")
    # Soft-delete marker: defaults to 1; delete() without _hard sets it to 0.
    enabled_flag = Column(Boolean(), default=1, nullable=False, comment="is_deleted")

    @classmethod
    async def get(cls, _id: typing.Union[int, str], to_dict=False) -> typing.Union["BaseModel", typing.Dict]:
        """Fetch the row whose primary key is *_id* as a model instance (or None).

        With ``to_dict=True`` the result is passed through ``unwrap_scalars``.
        """
        sql = select(cls).where(cls.id == _id)
        result = await cls.execute(sql)
        data = result.scalar()
        return data if not to_dict else unwrap_scalars(data)

    @classmethod
    async def create(cls, params: typing.Dict, to_dict: bool = False) -> typing.Union["BaseModel", typing.Dict]:
        """Insert one row built from *params* and return it.

        Keys that are not attributes of the model are dropped. Audit fields are
        filled by :meth:`handle_params`.

        Returns:
            A new (unattached) model instance, or the inserted params dict
            including ``id`` when ``to_dict`` is true.

        Raises:
            ValueError: *params* is not a dict.
        """
        if not isinstance(params, dict):
            raise ValueError("insert params error")
        params = {key: value for key, value in params.items() if hasattr(cls, key)}
        params = await cls.handle_params(params)
        stmt = insert(cls).values(**params)
        result = await cls.execute(stmt)
        (primary_key,) = result.inserted_primary_key
        params["id"] = primary_key
        return cls(**params) if not to_dict else params

    @classmethod
    async def upsert(cls, params: typing.Union[typing.Dict]) -> typing.Dict[typing.Text, typing.Any]:
        """Update the row identified by ``params["id"]`` if present, else insert a new row.

        Returns the written params; after an insert they include the new ``id``.

        Raises:
            ValueError: *params* is not a dict.
        """
        if not isinstance(params, dict):
            raise ValueError("upsert params error")
        params = {key: value for key, value in params.items() if hasattr(cls, key)}
        params = await cls.handle_params(params)
        _id = params.get("id", None)
        if _id:
            stmt = update(cls).where(cls.id == _id).values(**params)
        else:
            stmt = insert(cls).values(**params)
        result = await cls.execute(stmt)
        if result.is_insert:
            (primary_key,) = result.inserted_primary_key
            params["id"] = primary_key
        return params

    @classmethod
    async def delete(cls, _id: typing.Union[int, str], _hard: bool = False) -> int:
        """Delete the row with primary key *_id*; return the affected row count.

        By default this is a soft delete (``enabled_flag=0``). Pass
        ``_hard=True`` to issue a real DELETE.
        """
        if _hard is False:
            stmt = update(cls).where(cls.id == _id).values(enabled_flag=0)
        else:
            stmt = delete(cls).where(cls.id == _id)
        result = await cls.execute(stmt)
        return result.rowcount

    @classmethod
    async def get_all(cls) -> typing.Optional[typing.Any]:
        """Return every row as a list of dicts, or None when the table is empty."""
        # Unpack the columns so each one is its own SELECT-list entry, as the
        # other column queries do, instead of passing the ClauseList as a
        # single expression.
        stmt = select(*cls.get_columns())
        return await cls.get_result(stmt)

    @classmethod
    @database_session
    async def execute(cls, stmt: Executable, params: typing.Any = None, session: AsyncSession = None) \
            -> Result[typing.Any]:
        """Execute *stmt* on *session*; the decorator supplies one when omitted."""
        return await session.execute(stmt, params)

    @classmethod
    def get_columns(cls, exclude: set = None) -> ClauseList:
        """Return the table's columns, minus names in *exclude*, as a ClauseList.

        Callers unpack the result, e.g. ``select(*cls.get_columns())``.
        """
        exclude = exclude if exclude else set()
        return ClauseList(*[column for column in cls.__table__.columns if column.name not in exclude])

    @classmethod
    async def get_result(cls, stmt: Select, first=False) -> typing.Any:
        """Run *stmt* and return its rows serialized via ``unwrap_scalars``.

        With ``first=True`` only the first row is returned. Returns None when
        nothing matched.
        """
        result = await cls.execute(stmt)
        data = result.first() if first else result.fetchall()
        return unwrap_scalars(data) if data else None

    @classmethod
    async def handle_params(cls, params: typing.Any) -> typing.Any:
        """Stamp audit fields onto *params* (a dict, or a list of dicts handled recursively).

        - Keys that are not attributes of the model are dropped.
        - ``trace_id`` is taken from ``g``.
        - If a user is logged in, ``updated_by`` is set to their id, as is
          ``created_by`` for new rows (no ``id``).
        - Otherwise a missing ``updated_by``/``created_by`` defaults to 0.
        """
        if isinstance(params, dict):
            params = {key: value for key, value in params.items() if hasattr(cls, key)}
            updated_by = params.get("updated_by", None)
            created_by = params.get("created_by", None)
            params["trace_id"] = g.trace_id
            try:
                user = await current_user()
            except AccessTokenFail:
                user = None
            if user:
                user_id = user.get("id", None)
                params["updated_by"] = user_id
                if not params.get("id", None):
                    params["created_by"] = user_id
            else:
                if not updated_by:
                    params["updated_by"] = 0
                if not created_by:
                    params["created_by"] = 0
        elif isinstance(params, list):
            params = [await cls.handle_params(p) for p in params]
        return params
  • UserModel
from sqlalchemy import Column, String, Text, Integer, JSON

from sqlalchemy import select
from web.backend.app.model.base import BaseModel


class User(BaseModel):
    """User account table ``t_user``."""
    __tablename__ = "t_user"
    username = Column(String(64), nullable=False, comment="username", index=True)
    password = Column(Text, nullable=False, comment="password")
    email = Column(String(64), nullable=True, comment="email")
    permission = Column(Integer, nullable=False, comment="permission")
    status = Column(Integer, nullable=False, comment="user status", default=0)
    remarks = Column(String(255), nullable=True, comment="desc for user")
    avatar = Column(Text, nullable=True, comment="icon")
    tags = Column(JSON, nullable=True, comment="tags")

    @classmethod
    async def get_user_by_name(cls, username: str):
        """Return the first active user named *username* as a dict, or None.

        Soft-deleted rows (``enabled_flag == 0``) are excluded, so a deleted
        account can no longer be found for login.
        """
        stmt = select(*cls.get_columns()).where(
            cls.username == username,
            cls.enabled_flag == 1,
        )
        return await cls.get_result(stmt, True)

    @classmethod
    async def list(cls):
        """Return all active users as dicts, newest id first, or None if there are none."""
        stmt = (
            select(*cls.get_columns())
            .where(cls.enabled_flag == 1)
            .order_by(cls.id.desc())
        )
        return await cls.get_result(stmt)
2.5Schema类模块
  • UserSchema
import typing

from pydantic import BaseModel, Field
from web.backend.app.common.encrypto import decrypto_rsa_password


class UserLogin(BaseModel):
    # Login request body; both fields are required.
    username: str = Field(default=..., description="username")
    password: str = Field(default=..., description="password")


class UserQuery(BaseModel):
    # Optional user filter. ``username`` may be omitted or null, so it is
    # annotated Optional to match its None default.
    username: typing.Optional[str] = Field(None, description="username")


class UserCreate(BaseModel):
    # User creation / registration body. Only ``username`` is required; the
    # other fields default to None and are annotated Optional accordingly.
    id: typing.Optional[int] = Field(None, title="id", description="id")
    username: str = Field(..., title="username", description="username")
    # Default password. Written as a literal: decrypto_rsa_password("zer0py2c")
    # always evaluated to this same string, because the literal is not a valid
    # RSA ciphertext and that helper returns its input unchanged on failure.
    password: str = Field(description="password", default="zer0py2c")
    email: typing.Optional[str] = Field(None, description="email")
    permission: typing.Optional[int] = Field(None, description="permission")
    remarks: typing.Optional[str] = Field(None, description="remarks")
    avatar: typing.Optional[str] = Field(None, description="avatar")
    tags: typing.Optional[typing.List] = Field(None, description="tags")


class UserResetPassword(BaseModel):
    # Password change request: the target user id, the current password, and
    # the new password entered twice. All fields are required.
    id: int = Field(default=..., description="id")
    old_pwd: str = Field(default=..., description="old_pwd")
    new_pwd: str = Field(default=..., description="new_pwd")
    re_new_pwd: str = Field(default=..., description="re_new_pwd")


class UserDelete(BaseModel):
    # Identifies the user to delete by primary key (required).
    id: int = Field(default=..., description="id", title="id")
2.6服务模块
  • UserService
import uuid
import typing
import traceback

from datetime import datetime
from loguru import logger
from web.backend.app.core import g
from web.backend.app.core.codeenum import CodeEnum
from web.backend.app.core.serialize import default_serialize
from web.backend.app.model.user import User
from web.backend.app.schema.user import UserLogin, UserCreate, \
UserResetPassword, UserDelete
from web.backend.app.common.constant import TEST_USER_INFO, CACHE_DAY
from web.backend.app.common.encrypto import decrypto_rsa_password, \
encrypto_rsa_password


class UserService(object):
    """User business operations: login/logout, registration and account maintenance.

    Failures are raised as ``ValueError`` whose message is a CodeEnum message.
    """

    @staticmethod
    def _client_ip():
        """Return the caller's IP (X-Real-IP header first, then the socket peer), or None."""
        request = g.request
        if request is None:
            return None
        ip = request.headers.get("X-Real-IP", None)
        if not ip and request.client is not None:
            ip = request.client.host
        return ip

    @staticmethod
    async def login(params: UserLogin) -> typing.Dict[typing.Text, typing.Any]:
        """Authenticate *params* and cache a new session token in Redis for CACHE_DAY.

        Returns the cached token payload: id, token, login_time, username,
        permission label ("super"/"normal") and tags.

        Raises:
            ValueError: a credential is missing, or it does not match.
        """
        username = params.username
        password = params.password
        # Reject the request if either credential is empty.
        if not username or not password:
            raise ValueError(CodeEnum.PARTNER_CODE_PARAMS_FAIL.msg)
        user = await User.get_user_by_name(username)
        if not user:
            raise ValueError(CodeEnum.WRONG_USER_NAME_OR_PASSWORD.msg)

        if decrypto_rsa_password(user["password"]) != password:
            raise ValueError(CodeEnum.WRONG_USER_NAME_OR_PASSWORD.msg)
        token = str(uuid.uuid4())
        login_time = default_serialize(datetime.now())
        tags = user.get("tags", None)
        # Treat a missing permission as 0 so the comparison cannot raise TypeError.
        permission = user.get("permission") or 0
        token_user = {
            "id": user["id"], "token": token,
            "login_time": login_time, "username": username,
            "permission": "super" if permission > 0 else "normal",
            "tags": tags if tags else []
        }
        await g.redis.set(TEST_USER_INFO.format(token), token_user, CACHE_DAY)
        logger.info("Login success! username: {}, ip: {}".format(username, UserService._client_ip()))
        return token_user

    @staticmethod
    async def logout():
        """Best-effort removal of the cached session for the request's token header."""
        token = g.request.headers.get("token", None)
        try:
            await g.redis.delete(TEST_USER_INFO.format(token))
        except Exception:
            logger.error(traceback.format_exc())

    @staticmethod
    async def register(params: UserCreate) -> "User":
        """Create a user from *params*.

        Raises:
            ValueError: the username is already taken.
        """
        user = await User.get_user_by_name(params.username)
        if user:
            raise ValueError(CodeEnum.USERNAME_OR_EMAIL_IS_REGISTER.msg)
        return await User.create(params.dict())

    @staticmethod
    async def delete(params: UserDelete):
        """Soft-delete the user; return the affected row count, or None on error (logged)."""
        try:
            return await User.delete(params.id)
        except Exception:
            logger.error(traceback.format_exc())

    @staticmethod
    async def verify(token: str) -> typing.Dict[typing.Text, typing.Any]:
        """Return ``{id, username}`` for a cached session token.

        Raises:
            ValueError: the token is unknown or expired.
        """
        user = await g.redis.get(TEST_USER_INFO.format(token))
        if not user:
            raise ValueError(CodeEnum.PARTNER_CODE_TOKEN_EXPIRED_FAIL.msg)

        return {
            "id": user.get("id", None),
            "username": user.get("username", None)
        }

    @staticmethod
    async def reset_password(params: UserResetPassword):
        """Change a user's password after checking the old one.

        Raises:
            ValueError: the two new passwords differ, the user does not exist,
                the old password is wrong, or the new password equals the old one.
        """
        if params.new_pwd != params.re_new_pwd:
            raise ValueError(CodeEnum.PASSWORD_TWICE_IS_NOT_AGREEMENT.msg)
        user_info = await User.get(params.id)
        if not user_info:
            raise ValueError(CodeEnum.PARTNER_CODE_EMPLOYEE_FAIL.msg)
        pwd = decrypto_rsa_password(user_info.password)
        if params.old_pwd != pwd:
            raise ValueError(CodeEnum.OLD_PASSWORD_ERROR.msg)
        if params.new_pwd == pwd:
            raise ValueError(CodeEnum.NEW_PWD_NO_OLD_PWD_EQUAL.msg)
        new_pwd = encrypto_rsa_password(params.new_pwd)
        # The password column is Text; store the base64 ciphertext as str.
        if isinstance(new_pwd, bytes):
            new_pwd = new_pwd.decode()
        # The model has no ``update`` helper; ``upsert`` with an id issues an UPDATE.
        await User.upsert({"password": new_pwd, "id": params.id})

    @staticmethod
    async def get_user_by_token(token: str) -> \
            typing.Union[typing.Dict[typing.Text, typing.Any], None]:
        """Return profile fields for the user behind a cached session token.

        Raises:
            ValueError: the token is unknown, or the user no longer exists.
        """
        token_user = await g.redis.get(TEST_USER_INFO.format(token))
        if not token_user:
            raise ValueError(CodeEnum.PARTNER_CODE_TOKEN_EXPIRED_FAIL.msg)
        user_info = await User.get(token_user.get("id"))
        if not user_info:
            raise ValueError(CodeEnum.PARTNER_CODE_TOKEN_EXPIRED_FAIL.msg)
        return {
            "id": user_info.id, "avatar": user_info.avatar,
            "username": user_info.username,
            # The User model defines no ``nickname`` column, so plain attribute
            # access raised AttributeError; fall back to None.
            "nickname": getattr(user_info, "nickname", None),
            "permission": user_info.permission, "tags": user_info.tags,
            "login_time": token_user.get("login_time", None)
        }

    @staticmethod
    async def get_all_user():
        """Return every user via ``User.list()``."""
        return await User.list()
2.7 API接口模块
  • UserAPI
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from fastapi import APIRouter, Request

from web.backend.app.core.httpresp import partner_success
from web.backend.app.schema.user import UserLogin, UserCreate, UserDelete
from web.backend.app.service.user import UserService

# Router holding the user endpoints; mounted by the v1 API router.
router: APIRouter = APIRouter()


@router.post("/login", description="login")
async def login(params: UserLogin):
    """Authenticate with the submitted credentials and return the session payload."""
    session = await UserService.login(params)
    return partner_success(session, msg="Login Success!")


@router.post("/logout", description="logout")
async def logout():
    """End the caller's session; UserService reads the token from the request."""
    await UserService.logout()
    return partner_success()


@router.post("/get_user_by_token", description="get_user_by_token")
async def get_user_by_token(request: Request):
    """Look up the user owning the ``token`` header and return their profile."""
    profile = await UserService.get_user_by_token(request.headers.get("token", None))
    return partner_success(profile)


@router.post("/list", description="list")
async def get_user_by_name():
    """Return every user.

    Despite its name, this handler takes no username; it serves ``/list``.
    """
    return partner_success(await UserService.get_all_user())


@router.post("/register", description="register")
async def user_register(user: UserCreate):
    """Register a new account and return the record created by the service."""
    created = await UserService.register(user)
    return partner_success(created)


@router.post("/delete", description="delete")
async def delete(params: UserDelete):
    """Delete the user identified by ``params`` and echo the service's result."""
    return partner_success(await UserService.delete(params))


@router.post("/verify", description="verify")
async def verify(request: Request):
    """Validate the ``token`` header via UserService.verify and wrap its result."""
    session_token = request.headers.get("token", None)
    return partner_success(await UserService.verify(session_token))
  • 路由
1
2
3
4
5
6
from fastapi import APIRouter
from web.backend.app.apiv1.endpoint import user, file

# Aggregate v1 router; each endpoint module is mounted under its own prefix.
# NOTE(review): the import line above also pulls in ``file``, which is never
# mounted here and is not listed under app/apiv1/endpoint in the project tree;
# confirm it exists, otherwise that import fails at startup.
api_router = APIRouter()
api_router.include_router(user.router, tags=["user"], prefix="/user")
2.8 初始化数据库
1
2
3
4
5
6
7
8
from web.backend.app.core.init.database import db_engine
from web.backend.app.model.base import BaseModel


async def init_db(*, drop_existing: bool = True):
    """Create every table registered on ``BaseModel.metadata``.

    Args:
        drop_existing: When true (the default, matching the original
            behaviour), all tables are dropped first, which DELETES ALL
            DATA. Pass ``False`` to only create missing tables and keep
            existing rows.
    """
    async with db_engine.begin() as conn:
        if drop_existing:
            await conn.run_sync(BaseModel.metadata.drop_all)
        # create_all defaults to checkfirst=True, so existing tables are skipped.
        await conn.run_sync(BaseModel.metadata.create_all)
2.9 主程序
  • ENV变量
1
2
REDIS_URI=redis://127.0.0.1:6379/0
SQLITE3_DATABASE_URI=sqlite+aiosqlite:///./sql_app.db?check_same_thread=False
  • 系统配置
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import typing

from pydantic import BaseSettings, AnyHttpUrl, Field

# Project description, exposed as Config.PROJECT_DESC and passed to FastAPI
# as the application description.
project_desc = "\nmaintainer: zer0py2c\n"


class Config(BaseSettings):
    """Application settings loaded by pydantic ``BaseSettings``.

    Fields may be overridden by (case-sensitive) environment variables or by
    entries in ``.env``. ``REDIS_URI`` and ``DATABASE_URI`` (read from
    ``SQLITE3_DATABASE_URI``) have no default and must be provided.
    """

    PROJECT_DESC: str = project_desc
    # Default is a string so it matches the annotation; FastAPI's ``version``
    # argument is a string as well.
    PROJECT_VERSION: typing.Union[int, str] = "1.0"
    BASE_URL: AnyHttpUrl = "http://127.0.0.1:9999"
    API_PREFIX: str = "/api/v1"
    STATIC_DIR: str = "static"
    GLOBAL_ENCODING: str = "utf8"
    CORS_ORIGINS: typing.List[typing.Any] = ["*"]
    # Presumably the paths that skip token checks; confirm in the middleware.
    WHITE_ROUTER: typing.List[str] = ["/api/v1/user/login"]
    # NOTE(review): a secret committed to source is shared by every
    # deployment; override it through the SECRET_KEY environment variable.
    SECRET_KEY: str = "kPBDjVk0o3Y1wLxdODxBpjwEjo7-Euegg4kdnzFIRjc"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1
    REDIS_URI: str = Field(..., env="REDIS_URI")
    DATABASE_URI: str = Field(..., env="SQLITE3_DATABASE_URI")  # SQLite3 async
    # DATABASE_URI: str = "mysql+asyncmy://root:123456@localhost:3306/zer0py2c?charset=UTF8MB4"  # MySQL async
    # DATABASE_URI: str = "postgresql+asyncpg://postgres:123456@localhost:5432/postgres"  # PostgreSQL async
    DATABASE_ECHO: bool = False
    LOGGER_DIR: str = "logs"
    LOGGER_NAME: str = "web_api.log"
    LOGGER_LEVEL: str = "INFO"
    LOGGER_ROTATION: str = "10MB"
    LOGGER_RETENTION: str = "7days"
    # Directory two levels above this file. Resolving __file__ first keeps
    # the result correct even when __file__ is a bare relative name.
    BASEDIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    TEST_FILES_DIR: str = os.path.join(BASEDIR, "files")

    class Config:
        case_sensitive = True
        env_file = ".env"


# Module-level settings instance; other modules import it as ``config``.
config: Config = Config()
  • 主模块
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import uvicorn

from fastapi import FastAPI, Depends
from web.backend.config import config
from web.backend.app.core.init import *
# from web.backend.app.model import init_db

app = FastAPI(
    title="fastapi_project_demo",
    description=config.PROJECT_DESC,
    version=config.PROJECT_VERSION,
    dependencies=[Depends(init_global_request)]
)


def init_app():
    """Wire logging, exception handlers, routers, middleware and CORS into ``app``.

    Must run before the application starts serving: Starlette builds its
    middleware stack on the first ASGI call (the lifespan startup included),
    after which ``add_middleware`` raises and new exception handlers are not
    picked up.
    """
    init_logger()
    init_exception(app)
    init_router(app)
    init_middleware(app)
    init_cors(app)


# Configure the application at import time, before uvicorn starts it.
init_app()


@app.on_event("startup")
async def startup():
    """Open the shared Redis pool and store it on ``app.state.redis``."""
    # await init_db()
    app.state.redis = await init_redis_pool()


@app.on_event("shutdown")
async def shutdown():
    """Close the Redis pool opened in ``startup``, if it was created."""
    redis = getattr(app.state, "redis", None)
    if redis is not None:
        await redis.close()


if __name__ == "__main__":
    # Local development entry point: serve "main:app" with auto-reload.
    uvicorn.run(
        app="main:app",
        host="127.0.0.1",
        port=9999,
        reload=True,
    )