1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
| import typing
from sqlalchemy import Column, DateTime, Integer, func, select, \ update, delete, insert, Select, Executable, Result, ClauseList, String, Boolean from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.declarative import as_declarative from web.backend.app.core import g from web.backend.app.core.init.database import database_session from web.backend.app.core.init.exception import AccessTokenFail from web.backend.app.core.serialize import unwrap_scalars from web.backend.app.common.utils import current_user
@as_declarative()
class BaseModel(object):
    """Declarative base shared by all ORM models.

    Provides common audit columns (id, timestamps, creator/updater ids,
    trace id, soft-delete flag) and async CRUD helpers. Every query goes
    through :meth:`execute`, whose ``AsyncSession`` is expected to be
    supplied by the ``database_session`` decorator.
    """

    __tablename__: str
    __table_args__ = {"mysql_charset": "utf8"}
    __mapper_args__ = {"eager_defaults": True}

    id = Column(Integer(), nullable=False, primary_key=True, autoincrement=True)
    created_on = Column(DateTime(), default=func.now(), comment="createTime")
    updated_on = Column(DateTime(), default=func.now(), onupdate=func.now(), nullable=False, comment="updateTime")
    created_by = Column(Integer, nullable=True, comment="creator_id")
    updated_by = Column(Integer, nullable=True, comment="updater_id")
    trace_id = Column(String(255), nullable=False, comment="trace_id")
    # 1 = active (the default), 0 = soft-deleted by ``delete``.
    # NOTE(review): the DB comment "is_deleted" describes the inverse meaning;
    # it is left unchanged here to avoid producing a schema/migration diff.
    enabled_flag = Column(Boolean(), default=1, nullable=False, comment="is_deleted")

    @classmethod
    def _filter_columns(cls, params: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Return a new dict holding only the entries of *params* that are mapped columns.

        Filtering on ``hasattr(cls, key)`` would also accept names such as
        ``metadata``, ``registry`` or method names (``get``, ``delete``),
        which make ``insert(...).values()`` / ``update(...).values()`` fail.
        """
        columns = cls.__mapper__.column_attrs
        return {key: value for key, value in params.items() if key in columns}

    @classmethod
    async def get(cls, _id: typing.Union[int, str], to_dict=False) -> typing.Union["BaseModel", typing.Dict, None]:
        """Fetch a single row by primary key.

        :param _id: primary key value.
        :param to_dict: when true, return the row converted by ``unwrap_scalars``.
        :return: the model instance (or dict), or ``None`` when no row matches.

        NOTE: rows soft-deleted via :meth:`delete` (``enabled_flag == 0``) are
        not filtered out.
        """
        stmt = select(cls).where(cls.id == _id)
        result = await cls.execute(stmt)
        instance = result.scalar()
        # Don't hand ``None`` to unwrap_scalars on a miss; mirrors get_result.
        if instance is None or not to_dict:
            return instance
        return unwrap_scalars(instance)

    @classmethod
    async def create(cls, params: typing.Dict, to_dict: bool = False) -> typing.Union["BaseModel", typing.Dict]:
        """Insert a new row built from *params*.

        Non-column keys are dropped and audit fields are filled in by
        :meth:`handle_params`.

        :param params: column values for the new row.
        :param to_dict: when true, return the inserted values as a dict
            instead of a (transient) model instance.
        :return: the inserted values including the generated ``id``.
        :raises ValueError: if *params* is not a dict.
        """
        if not isinstance(params, dict):
            raise ValueError("insert params error")
        # handle_params filters to mapped columns, so no pre-filtering is needed.
        params = await cls.handle_params(params, is_insert=True)
        result = await cls.execute(insert(cls).values(**params))
        (primary_key,) = result.inserted_primary_key
        params["id"] = primary_key
        return params if to_dict else cls(**params)

    @classmethod
    async def upsert(cls, params: typing.Dict[str, typing.Any]) -> typing.Dict[typing.Text, typing.Any]:
        """Update the row identified by ``params["id"]``, or insert when no id is given.

        :param params: column values; a truthy ``id`` selects the update path.
        :return: the values written; on insert, ``id`` is set to the new key.
        :raises ValueError: if *params* is not a dict.
        """
        if not isinstance(params, dict):
            raise ValueError("upsert params error")
        params = await cls.handle_params(params)
        _id = params.get("id", None)
        if _id:
            stmt = update(cls).where(cls.id == _id).values(**params)
        else:
            stmt = insert(cls).values(**params)
        result = await cls.execute(stmt)
        if result.is_insert:
            (primary_key,) = result.inserted_primary_key
            params["id"] = primary_key
        return params

    @classmethod
    async def delete(cls, _id: typing.Union[int, str], _hard: bool = False) -> int:
        """Delete the row with primary key *_id*.

        :param _id: primary key value.
        :param _hard: when false (default), soft-delete by setting
            ``enabled_flag`` to 0; when true, issue a real ``DELETE``.
        :return: the number of affected rows.
        """
        # ``delete`` below resolves to sqlalchemy's module-level function,
        # not this classmethod (class scope is not visible in method bodies).
        if _hard is False:
            stmt = update(cls).where(cls.id == _id).values(enabled_flag=0)
        else:
            stmt = delete(cls).where(cls.id == _id)
        result = await cls.execute(stmt)
        return result.rowcount

    @classmethod
    async def get_all(cls) -> typing.Optional[typing.Any]:
        """Return every row of the table (all columns) via :meth:`get_result`.

        NOTE: soft-deleted rows (``enabled_flag == 0``) are included.
        """
        stmt = select(cls.get_columns())
        return await cls.get_result(stmt)

    @classmethod
    @database_session
    async def execute(
            cls,
            stmt: Executable,
            params: typing.Any = None,
            session: typing.Optional[AsyncSession] = None,
    ) -> Result[typing.Any]:
        """Execute *stmt* with optional bind *params* on *session*.

        *session* is presumably injected by the ``database_session``
        decorator; callers in this class never pass it explicitly.
        """
        return await session.execute(stmt, params)

    @classmethod
    def get_columns(cls, exclude: typing.Optional[set] = None) -> ClauseList:
        """Return the table's columns as a ``ClauseList``, skipping names in *exclude*."""
        exclude = exclude if exclude else set()
        return ClauseList(*[column for column in cls.__table__.columns if column.name not in exclude])

    @classmethod
    async def get_result(cls, stmt: Select, first=False) -> typing.Any:
        """Execute *stmt* and return the first row or all rows via ``unwrap_scalars``.

        :param stmt: the SELECT to run.
        :param first: when true, return only the first row.
        :return: the unwrapped row(s), or ``None`` when nothing was returned.
        """
        result = await cls.execute(stmt)
        data = result.first() if first else result.fetchall()
        return unwrap_scalars(data) if data else None

    @classmethod
    async def handle_params(cls, params: typing.Any, *, is_insert: typing.Optional[bool] = None) -> typing.Any:
        """Normalize write parameters before an INSERT/UPDATE.

        For a dict this returns a new dict that:

        * keeps only keys that are mapped columns,
        * sets ``trace_id`` from ``g.trace_id``,
        * stamps ``updated_by`` with the current user's id (or 0 when there
          is no authenticated user and none was supplied),
        * stamps ``created_by`` the same way, but only for inserts, so that
          updates never overwrite the original creator.

        A list is processed element-wise; any other value is returned unchanged.

        :param params: a dict, a list of dicts, or anything else.
        :param is_insert: whether the values are for an INSERT. When ``None``
            (default) it is inferred as "no truthy ``id`` present".
        :return: the normalized params.
        """
        if isinstance(params, list):
            return [await cls.handle_params(p, is_insert=is_insert) for p in params]
        if not isinstance(params, dict):
            return params

        params = cls._filter_columns(params)
        if is_insert is None:
            is_insert = not params.get("id", None)
        params["trace_id"] = g.trace_id

        try:
            user = await current_user()
        except AccessTokenFail:
            # Unauthenticated callers are allowed; fall back to id 0 below.
            user = None

        if user:
            user_id = user.get("id", None)
            params["updated_by"] = user_id
            if is_insert:
                params["created_by"] = user_id
        else:
            if not params.get("updated_by", None):
                params["updated_by"] = 0
            # Only default the creator on insert; on update this would
            # clobber the existing created_by with 0.
            if is_insert and not params.get("created_by", None):
                params["created_by"] = 0
        return params
|