Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Day 03 #15

Open
wants to merge 8 commits into
base: day-03
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 15 additions & 28 deletions www/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import aiomysql

def log(sql, args=()):
logging.info('SQL: %s' % sql)
def log(sql, args=[]):
logging.info('SQL: [%s] args: %s' % (sql, args))

async def create_pool(loop, **kw):
logging.info('create database connection pool...')
Expand Down Expand Up @@ -40,7 +40,7 @@ async def select(sql, args, size=None):
return rs

async def execute(sql, args, autocommit=True):
log(sql)
log(sql, args)
async with __pool.get() as conn:
if not autocommit:
await conn.begin()
Expand All @@ -56,12 +56,6 @@ async def execute(sql, args, autocommit=True):
raise
return affected

def create_args_string(num):
L = []
for n in range(num):
L.append('?')
return ', '.join(L)

class Field(object):

def __init__(self, name, column_type, primary_key, default):
Expand Down Expand Up @@ -103,34 +97,32 @@ class ModelMetaclass(type):
def __new__(cls, name, bases, attrs):
if name=='Model':
return type.__new__(cls, name, bases, attrs)
tableName = attrs.get('__table__', None) or name
tableName = attrs.get('__table__', name)
logging.info('found model: %s (table: %s)' % (name, tableName))
mappings = dict()
fields = []
escaped_fields = []
primaryKey = None
for k, v in attrs.items():
for k, v in attrs.copy().items():
if isinstance(v, Field):
logging.info(' found mapping: %s ==> %s' % (k, v))
mappings[k] = v
mappings[k] = attrs.pop(k)
if v.primary_key:
# 找到主键:
if primaryKey:
raise StandardError('Duplicate primary key for field: %s' % k)
primaryKey = k
else:
fields.append(k)
escaped_fields.append(k)
if not primaryKey:
raise StandardError('Primary key not found.')
for k in mappings.keys():
attrs.pop(k)
escaped_fields = list(map(lambda f: '`%s`' % f, fields))

attrs['__mappings__'] = mappings # 保存属性和列的映射关系
attrs['__table__'] = tableName
attrs['__primary_key__'] = primaryKey # 主键属性名
attrs['__fields__'] = fields # 除主键外的属性名
attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
attrs['__fields__'] = escaped_fields + [primaryKey] # 全部属性名,主键一定在是最后
attrs['__select__'] = 'select * from `%s`' % (tableName)
attrs['__insert__'] = 'insert into `%s` (%s) values (%s)' % (tableName, ', '.join('`%s`' % f for f in attrs['__fields__']), ', '.join('?' * len(mappings)))
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join('`%s`=?' % f for f in escaped_fields), primaryKey)
attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
return type.__new__(cls, name, bases, attrs)

Expand All @@ -148,9 +140,6 @@ def __getattr__(self, key):
def __setattr__(self, key, value):
self[key] = value

def getValue(self, key):
return getattr(self, key, None)

def getValueOrDefault(self, key):
value = getattr(self, key, None)
if value is None:
Expand Down Expand Up @@ -210,20 +199,18 @@ async def find(cls, pk):

async def save(self):
args = list(map(self.getValueOrDefault, self.__fields__))
args.append(self.getValueOrDefault(self.__primary_key__))
rows = await execute(self.__insert__, args)
if rows != 1:
logging.warn('failed to insert record: affected rows: %s' % rows)

async def update(self):
args = list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
args = list(map(self.get, self.__fields__))
rows = await execute(self.__update__, args)
if rows != 1:
logging.warn('failed to update by primary key: affected rows: %s' % rows)

async def remove(self):
args = [self.getValue(self.__primary_key__)]
args = [self.get(self.__primary_key__)]
rows = await execute(self.__delete__, args)
if rows != 1:
logging.warn('failed to remove by primary key: affected rows: %s' % rows)