X Tutup
# encoding: utf-8
"""Thin convenience wrapper around a psycopg2 PostgreSQL connection.

The source is written so that it parses under both Python 2.6+ and Python 3
(``except ... as e`` and single-argument ``print(...)`` calls).
"""

__author__ = 'zhanghe'

# NOTE: the star import stays first, exactly as in the original file, so the
# explicit imports below win should psycopg2 ever export a clashing name.
# It is what brings ``connect`` into scope.
from psycopg2 import *
import json
from datetime import date, datetime

# Dedicated logging configuration for database activity.
from log import Logger

my_logger = Logger('postgres', 'postgres.log', 'DEBUG')
my_logger.set_file_level('DEBUG')
my_logger.set_stream_level('WARNING')  # WARNING DEBUG
my_logger.set_stream_handler_fmt('%(message)s')
my_logger.load()
logger = my_logger.logger
# my_logger.get_memory_usage()


class Postgres(object):
    """Custom PostgreSQL helper built on a single psycopg2 connection.

    SECURITY NOTE: ``table_name``, the ``condition`` / ``update_field``
    fragments and the ``limit`` clause are spliced into the SQL text verbatim.
    They must come from trusted code, never from end-user input.
    """

    def __init__(self, db_config, db_name=None):
        """Open a connection described by ``db_config``.

        :param db_config: mapping with the keys ``database``, ``user``,
            ``password``, ``host`` and ``port``.
        :param db_name: optional database name overriding
            ``db_config['database']``.

        Connection errors are logged, not raised; ``self.conn`` is then
        ``None`` and ``is_conn_open()`` returns ``False``.
        """
        # Work on a copy so the caller's dict is not silently modified.
        self.db_config = dict(db_config)
        # Define the attribute up front: without this a failed connect()
        # left ``conn`` missing and every later call raised AttributeError.
        self.conn = None
        if db_name is not None:
            self.db_config['database'] = db_name
        try:
            self.conn = connect(
                database=self.db_config['database'],
                user=self.db_config['user'],
                password=self.db_config['password'],
                host=self.db_config['host'],
                port=self.db_config['port']
            )
        except Exception as e:
            logger.error(e)

    @staticmethod
    def __default(obj):
        """``json.dumps`` fallback adding ``datetime`` / ``date`` support.

        Without it json raises e.g.
        ``TypeError: datetime.datetime(2015, 10, 21, 8, 42, 54) is not JSON
        serializable``.

        :param obj: the value json could not serialize.
        :return: ``'%Y-%m-%d %H:%M:%S'`` text for a datetime, ``'%Y-%m-%d'``
            text for a date.
        :raises TypeError: for any other type.
        """
        # datetime must be tested first: it is a subclass of date.
        if isinstance(obj, datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(obj, date):
            return obj.strftime('%Y-%m-%d')
        raise TypeError('%r is not JSON serializable' % obj)

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _build_where(condition):
        """Return a ``where`` clause AND-ing the fragments, or ``''``."""
        if condition:
            return 'where ' + ' and '.join(condition)
        return ''

    def _query_args_ok(self, table_name, condition=None):
        """Log and return ``False`` if a table query cannot be attempted."""
        if self.is_conn_open() is False:
            logger.error('连接已断开')
            return False
        if table_name is None:
            logger.error('查询表名缺少参数')
            return False
        if condition and not isinstance(condition, list):
            logger.error('查询条件参数格式错误')
            return False
        return True

    def _rollback(self):
        """Roll back after a failed statement so the connection stays usable.

        PostgreSQL rejects every further statement of an aborted transaction
        until it is rolled back.
        """
        try:
            self.conn.rollback()
        except Exception as e:
            logger.error(e)

    def _run_query(self, sql, params=None, fetch_one=False):
        """Execute a read statement; return its row(s), or ``None`` on error."""
        cursor = self.conn.cursor()
        try:
            # With params=None psycopg2 performs no '%' interpolation, so
            # literal percent signs in hand-written conditions are safe.
            cursor.execute(sql, params)
            if fetch_one:
                return cursor.fetchone()
            return cursor.fetchall()
        except Exception as e:
            logger.error(e)
            self._rollback()
            return None
        finally:
            cursor.close()

    def _run_write(self, sql, rowcount_fmt, done_message=None):
        """Execute and commit a write; return ``True``, or ``None`` on error."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql)
            self.conn.commit()
            logger.info(rowcount_fmt % cursor.rowcount)
            if done_message is not None:
                logger.info(done_message)
            return True
        except Exception as e:
            logger.error(e)
            self._rollback()
            return None
        finally:
            cursor.close()

    @staticmethod
    def _print_aligned(columns_name, row):
        """Print ``name : value`` lines with right-aligned column names."""
        # The longest column name defines the alignment width.
        width = max(len(name) for name in columns_name)
        # Iterate the pairs directly (not through a dict) so the output keeps
        # the column order and duplicate names are not collapsed.
        for name, value in zip(columns_name, row):
            print('%s : %s' % (name.rjust(width), value))

    def _print_json(self, columns_name, row):
        """Print one record as indented JSON."""
        record = dict(zip(columns_name, row))
        print(json.dumps(record, indent=4, ensure_ascii=False,
                         default=self.__default))

    # ------------------------------------------------------------------
    # connection management
    # ------------------------------------------------------------------
    def is_conn_open(self):
        """Report whether the connection exists and is open.

        :return: ``True`` if open, ``False`` otherwise.
        """
        conn = getattr(self, 'conn', None)
        # psycopg2: ``closed`` is 0 when open and nonzero when closed OR
        # broken, so test for "nonzero" rather than "== 1".
        return conn is not None and not conn.closed

    def close_conn(self):
        """Close the database connection if it is open."""
        if self.is_conn_open() is True:
            self.conn.close()

    # ------------------------------------------------------------------
    # table-level helpers
    # ------------------------------------------------------------------
    def truncate(self, table_name):
        """Empty a table with ``truncate table``.

        :param table_name: table to truncate (trusted input only).
        :return: ``True`` on success; ``[]`` when the connection is closed or
            ``table_name`` is ``None`` (kept for backward compatibility);
            ``None`` if the statement failed.
        """
        if not self._query_args_ok(table_name):
            return []
        sql = 'truncate table %s' % table_name
        logger.info(sql)
        return self._run_write(sql, '更新行数：%s')

    def get_columns_name(self, table_name):
        """Return the column names of a table.

        :param table_name: unqualified table name.
        :return: list of names ordered by column position; ``[]`` when the
            connection is closed or ``table_name`` is ``None``; ``None`` if
            the query failed.

        NOTE(review): the lookup is not restricted to a schema, so a table
        name that exists in several schemas yields the columns of all of them.
        """
        if not self._query_args_ok(table_name):
            return []
        # Pass the name as a bound parameter instead of splicing it into a
        # string literal, and order by position so the names line up with the
        # rows returned by ``select *``.
        sql = ('select column_name from information_schema.columns '
               'where table_name = %s order by ordinal_position')
        params = (table_name,)
        logger.info('%s %r' % (sql, params))
        result = self._run_query(sql, params)
        if result is None:
            return None
        return [item[0] for item in result]

    def get_row(self, table_name, condition=None):
        """Fetch a single row.

        :param table_name: table to read (trusted input only).
        :param condition: optional list of SQL fragments joined with ``and``.
        :return: the row tuple, or ``None`` when nothing matched, the
            arguments were invalid or the query failed.
        """
        if not self._query_args_ok(table_name, condition):
            return None
        sql = 'select * from %s %s limit 1' % (
            table_name, self._build_where(condition))
        logger.info(sql)
        return self._run_query(sql, fetch_one=True)

    def get_rows(self, table_name, condition=None, limit='limit 10 offset 0'):
        """Fetch several rows.

        Examples::

            con_obj.get_rows('company', ["type='6'"], 'limit 10 offset 0')
            con_obj.get_rows('company', ["type='6'"], 'limit 10')

        :param table_name: table to read (trusted input only).
        :param condition: optional list of SQL fragments joined with ``and``.
        :param limit: raw ``limit``/``offset`` clause appended to the query.
        :return: list of row tuples, or ``None`` when the arguments were
            invalid or the query failed.
        """
        if not self._query_args_ok(table_name, condition):
            return None
        sql = 'select * from %s %s %s' % (
            table_name, self._build_where(condition), limit)
        logger.info(sql)
        return self._run_query(sql)

    def get_count(self, table_name, condition=None):
        """Count the rows of a table.

        :param table_name: table to count (trusted input only).
        :param condition: optional list of SQL fragments joined with ``and``.
        :return: the row count; ``0`` when the arguments were invalid;
            ``None`` if the query failed.
        """
        if not self._query_args_ok(table_name, condition):
            return 0
        sql = 'select count(*) from %s %s' % (
            table_name, self._build_where(condition))
        logger.info(sql)
        row = self._run_query(sql, fetch_one=True)
        if row is None:
            return None
        return row[0]

    # ------------------------------------------------------------------
    # formatted output
    # ------------------------------------------------------------------
    def output_row(self, table_name, condition=None, style=0):
        """Print a single record.

        :param table_name: table to read (trusted input only).
        :param condition: optional list of SQL fragments joined with ``and``.
        :param style: ``0`` prints aligned ``name : value`` lines, any other
            value prints indented JSON.
        :return: always ``None``; problems are logged.
        """
        if not table_name:
            logger.error('查询数据缺少参数')
            return None
        if condition and not isinstance(condition, list):
            logger.error('查询条件参数格式错误')
            return None
        columns_name = self.get_columns_name(table_name)
        if not columns_name:
            logger.error('表名不存在')
            return None
        row = self.get_row(table_name, condition)
        if not row:
            logger.error('记录不存在')
            return None
        if style == 0:
            print('********** 表名[%s] **********' % table_name)
            self._print_aligned(columns_name, row)
        else:
            self._print_json(columns_name, row)

    def output_rows(self, table_name, condition=None,
                    limit='limit 10 offset 0', style=0):
        """Print a batch of records.

        :param table_name: table to read (trusted input only).
        :param condition: optional list of SQL fragments joined with ``and``.
        :param limit: raw ``limit``/``offset`` clause appended to the query.
        :param style: ``0`` prints aligned ``name : value`` lines, any other
            value prints indented JSON.
        :return: always ``None``; problems are logged.
        """
        if not table_name:
            logger.error('查询数据缺少参数')
            return None
        if condition and not isinstance(condition, list):
            logger.error('查询条件参数格式错误')
            return None
        columns_name = self.get_columns_name(table_name)
        if not columns_name:
            logger.error('表名不存在')
            return None
        rows = self.get_rows(table_name, condition, limit)
        if not rows:
            logger.error('记录不存在')
            return None
        if style == 0:
            total = len(rows)
            for count, row in enumerate(rows, 1):
                print('********** 表名[%s] [%d/%d] **********'
                      % (table_name, count, total))
                self._print_aligned(columns_name, row)
        else:
            for row in rows:
                self._print_json(columns_name, row)

    # ------------------------------------------------------------------
    # writes
    # ------------------------------------------------------------------
    def update(self, table_name, update_field, condition=None):
        """Update rows.

        Example::

            con_obj.update('company', ["title='标题'", "flag='2'"], ["type='6'"])

        :param table_name: table to update (trusted input only).
        :param update_field: non-empty list of ``column=value`` fragments.
        :param condition: optional list of SQL fragments joined with ``and``.
        :return: ``True`` on success; ``False`` when the connection is closed
            or the arguments are invalid; ``None`` if the statement failed.
        """
        if self.is_conn_open() is False:
            logger.error('连接已断开')
            return False
        if not table_name or not update_field:
            logger.error('更新数据缺少参数')
            return False
        if not isinstance(update_field, list) or (
                condition and not isinstance(condition, list)):
            logger.error('更新数据参数格式错误')
            return False
        # SET assignments are comma separated; joining them with ' and '
        # (as the code used to) does not update each listed column.
        sql_update_field = 'set ' + ', '.join(update_field)
        sql = 'update %s %s %s' % (
            table_name, sql_update_field, self._build_where(condition))
        logger.info(sql)
        return self._run_write(sql, '更新行数：%s')

    def delete(self, table_name, condition=None):
        """Delete rows.

        Example::

            con_obj.delete('company', ["type='6'", "flag='2'"])

        :param table_name: table to delete from (trusted input only).
        :param condition: optional list of SQL fragments joined with ``and``;
            when omitted every row is deleted.
        :return: ``True`` on success; ``False`` when the connection is closed
            or the arguments are invalid; ``None`` if the statement failed.
        """
        if self.is_conn_open() is False:
            logger.error('连接已断开')
            return False
        if not table_name:
            logger.error('删除数据缺少参数')
            return False
        if condition and not isinstance(condition, list):
            logger.error('删除数据参数格式错误')
            return False
        sql = 'delete from %s %s' % (table_name, self._build_where(condition))
        logger.info(sql)
        return self._run_write(sql, '删除行数：%s', '删除成功')

    # ------------------------------------------------------------------
    # raw SQL
    # ------------------------------------------------------------------
    def query_by_sql(self, sql=None):
        """Run a raw ``select`` statement.

        :param sql: statement text; it must start with ``select``
            (case-insensitive, leading whitespace allowed).
        :return: list of row tuples, or ``None`` when the connection is
            closed, ``sql`` is missing or not a select, or the query failed.

        NOTE(review): the prefix test is only a sanity check, not a security
        boundary -- it does not inspect the rest of the statement.
        """
        if self.is_conn_open() is False:
            logger.error('连接已断开')
            return None
        if sql is None:
            logger.error('sql语句不能为空')
            return None
        logger.info(sql)
        # Lower-case only a copy for the check; executing a lower-cased
        # statement would corrupt string literals and quoted identifiers.
        if not sql.lstrip().lower().startswith('select'):
            logger.error('未授权的操作')
            return None
        return self._run_query(sql)

    def update_by_sql(self, sql=None):
        """Run a raw ``insert`` / ``update`` / ``delete`` statement.

        :param sql: statement text; it must start with one of those keywords
            (case-insensitive, leading whitespace allowed).
        :return: ``True`` on success; ``False`` when the connection is closed,
            ``sql`` is missing or uses another keyword; ``None`` if the
            statement failed.

        NOTE(review): the prefix test is only a sanity check, not a security
        boundary -- it does not inspect the rest of the statement.
        """
        if self.is_conn_open() is False:
            logger.error('连接已断开')
            return False
        if sql is None:
            logger.error('sql语句不能为空')
            return False
        logger.info(sql)
        # Check a lower-cased copy only; the original text is what gets
        # executed so the data written keeps its case.
        if not sql.lstrip().lower().startswith(('update', 'insert', 'delete')):
            logger.error('未授权的操作')
            return False
        return self._run_write(sql, '影响行数：%s', '执行成功')
X Tutup