Python读取CSV并导入数据库

Python 读取CSV并导入数据库

日常工作中,经常会遇到需要将 CSV 数据导入数据库的需求,因此写了这样一个简单的导入工具。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
'''
CSV数据导入工具
author: 韩进祥
'''
import json
from db.Query import Query
import copy

class ImportData():
    """Load rows from a delimited text file and bulk-insert them into a table.

    Typical use: construct with the file path, set ``table`` and
    ``table_fields`` (and any other setting that differs from the defaults),
    then call ``readCsv``, ``formatData`` and ``importData`` in that order.
    """

    def __init__(self, filepath):
        """Store the source path and set the default import settings.

        Args:
            filepath: Path of the delimited text file to import.
        """
        # Number of rows sent to the database per INSERT batch.
        self.max_row = 1000
        # Number of leading lines (e.g. a header row) to skip.
        self.ignore_num = 1
        # Column names in the order they appear in the file. Names that are
        # not columns of the target table are allowed; formatData drops them.
        self.table_fields = [

        ]
        # "ansi" is a codec alias that only exists on Windows; readCsv falls
        # back to the locale's preferred encoding when the name is unknown.
        self.encoding = "ansi"
        # Field separator.
        self.split = ","
        # Characters stripped from the end of every line.
        self.rstrip = "\n"
        self.filepath = filepath
        self.db = Query()
        # Name of the target table; must be set before formatData/importData.
        self.table = ""

    def readCsv(self):
        """Read the file and return its data rows.

        The first ``ignore_num`` lines are skipped. Every remaining line has
        the ``rstrip`` characters removed from its end and is then split on
        ``split``. Quoted fields are NOT interpreted: a separator inside
        quotes still splits the field.

        Returns:
            list[list[str]]: one list of field strings per line read.
        """
        import codecs
        import locale

        encoding = self.encoding
        if isinstance(encoding, str):
            try:
                codecs.lookup(encoding)
            except LookupError:
                # Unknown codec on this platform (the default "ansi" outside
                # Windows): use the system's preferred encoding instead.
                encoding = locale.getpreferredencoding(False)

        datas = []
        with open(self.filepath, "r", encoding=encoding) as f:
            for x in range(self.ignore_num):
                print("第%d行被忽略" % (x + 1))
                f.readline()
            for line in f:
                datas.append(line.rstrip(self.rstrip).split(self.split))
        print("读取到%d条数据" % len(datas))
        return datas

    def formatData(self, datas):
        """Turn raw rows into dicts holding only columns the table has.

        Each row is paired positionally with ``table_fields``; names that are
        not columns of ``self.table`` are left out. Rows shorter than
        ``table_fields`` simply yield fewer keys.

        Args:
            datas: Rows as returned by ``readCsv``.

        Returns:
            list[dict]: one ``{column: value}`` dict per input row.

        Raises:
            ValueError: if ``self.table`` has not been set.
        """
        if not self.table:
            raise ValueError("ImportData.table must be set before formatData()")
        # Ask the database which columns really exist in the target table.
        table_columns = set(self.db.getTableFields(self.table))
        fields = self.table_fields
        return [
            {
                name: value
                for name, value in zip(fields, item)
                if name in table_columns
            }
            for item in datas
        ]

    def importData(self, datas):
        """Insert the formatted rows into ``self.table``.

        Rows are sent in batches of ``self.max_row``.

        Args:
            datas: Row dicts as returned by ``formatData``.

        Returns:
            tuple: ``(rows_inserted, rows_expected)``; ``(0, 0)`` when there
            is nothing to insert.
        """
        if not datas:
            return 0, 0
        res = self.db.insertAll(self.table, datas, self.max_row)
        return res



if __name__ == "__main__":
    import sys

    # Path of the file to import: the first command-line argument when one is
    # given, otherwise the placeholder below (edit it before running).
    filepath = sys.argv[1] if len(sys.argv) > 1 else ""
    import_data = ImportData(filepath)
    datas = import_data.readCsv()
    formatDatas = import_data.formatData(datas)
    # importData returns (rows inserted, rows expected), matching the two
    # placeholders of the summary message.
    res = import_data.importData(formatDatas)
    print("执行完毕,成功插入%d条数据,理论应插入%d条数据" % res)

代码中出现的 db 是我基于 pymysql 封装的一个 Python 包,目前还是半成品,核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# 数据库操作类二次封装
import pymysql
import datetime
# import os
from pytz import timezone
from pathlib import Path
class Query():
    """Convenience wrapper around a single pymysql connection.

    Offers helpers to run raw SQL, to build INSERT statements from dicts and
    to append every executed statement to a per-day log file.
    """

    def __init__(self, config=None, is_init_config=True):
        """Build the effective configuration and open the connection.

        Args:
            config: Settings dict. Keys read by this class: ``host``,
                ``user``, ``passwd``, ``db``, ``port``, ``sql_log``,
                ``sql_log_path``, ``timezone``.
            is_init_config: When true, ``config`` is layered on top of the
                defaults in the package's ``dbconfig`` module; when false,
                ``config`` is used exactly as given.
        """
        if config is None:
            config = {}
        if is_init_config:
            from . import dbconfig
            # Caller-supplied keys override the packaged defaults.
            self.config = dict(dbconfig.config, **config)
        else:
            self.config = config
        self.conn = self.__connect(self.config)

    def __connect(self, config=None):
        """Open and return a pymysql connection.

        Args:
            config: Settings to connect with; defaults to ``self.config``.
        """
        cfg = config if config else self.config
        conn = pymysql.connect(
            host=cfg['host'],
            user=cfg['user'],
            passwd=cfg['passwd'],
            db=cfg['db'],
            # Fall back to 3306 when the port is absent or empty.
            port=int(cfg.get('port') or 3306),
        )
        return conn

    def query(self, sql, param=()):
        """Execute a statement and return the cursor.

        Errors are not raised: they are written to the SQL log and the
        (unexecuted) cursor is returned anyway.

        Args:
            sql: Statement text, optionally with pymysql placeholders.
            param: Values bound to the placeholders; when empty the statement
                is executed without parameter substitution.
        """
        cursor = self.conn.cursor()
        try:
            if param:
                cursor.execute(sql, param)
            else:
                cursor.execute(sql)
            msg = "执行成功"
            status = 1
        except Exception as e:
            # Deliberate best effort: record the failure, let the caller go on.
            msg = str(e)
            status = 0
        self.__wirteLog(sql, status, msg)
        return cursor

    def save(self, sql):
        """Execute a write statement, commit, and return the affected rows.

        On any error the transaction is rolled back, the error is logged and
        0 is returned.
        """
        cursor = self.conn.cursor()
        try:
            row = cursor.execute(sql)
            self.conn.commit()
            msg = "本次提交影响%d条数据" % row
            status = 1
        except Exception as e:
            self.conn.rollback()
            msg = str(e)
            status = 0
            row = 0
        finally:
            cursor.close()
        self.__wirteLog(sql, status, msg)
        return row

    def select(self, sql, param=()):
        """Run a SELECT through ``query`` and return all fetched rows."""
        cursor = self.query(sql, param)
        datas = cursor.fetchall()
        return datas

    def count(self, table, where=None, field="*"):
        """Return ``count(field)`` for ``table``.

        Args:
            table: Table name, inserted into the SQL as is.
            where: Either a ready-made condition string, or a dict of
                ``{column: value}`` pairs combined with AND (see
                ``whereToStr``). Empty / omitted means no WHERE clause.
            field: Expression placed inside ``count(...)``.

        Returns:
            The first column of the single result row.
        """
        if isinstance(where, dict):
            where = self.whereToStr(where)
        where_str = " WHERE %s " % where if where else ""
        sql = "SELECT count(%s) FROM %s %s" % (field, table, where_str)
        cursor = self.query(sql)
        # NOTE(review): if the query failed, fetchone() presumably yields no
        # row and the indexing below raises - confirm the desired behaviour.
        datas = cursor.fetchone()
        count = datas[0]
        return count

    def insert(self, table, data):
        """Insert one row.

        Args:
            table: Target table name.
            data: Flat dict; keys are column names, values the row values.

        Returns:
            Number of rows inserted, as reported by ``save``.
        """
        field_str = self.dictKeyInsertFormat(data)
        value_str = self.dictValueInsertFormat(data)
        sql = "INSERT INTO `%s`%s VALUE %s" % (table, field_str, value_str)
        row = self.save(sql)
        return row

    def insertAll(self, table, data, max_num=1000):
        """Insert many rows, ``max_num`` rows per INSERT statement.

        The column list is taken from the first row, so every row is expected
        to carry the same keys in the same order.

        Args:
            table: Target table name.
            data: Sequence of flat dicts, one per row.
            max_num: Maximum rows per statement; a value below 1 sends
                everything in a single statement.

        Returns:
            tuple: ``(rows_inserted, rows_expected)``; ``(0, 0)`` for empty
            input.
        """
        total = len(data)
        if total == 0:
            return 0, 0
        field_str = self.dictKeyInsertFormat(data[0])
        batch_size = max_num if max_num and max_num > 0 else total
        row = 0
        for start in range(0, total, batch_size):
            batch = data[start:start + batch_size]
            values_str = ",".join(
                self.dictValueInsertFormat(item) for item in batch
            )
            sql = "INSERT INTO `%s`%s VALUES %s" % (table, field_str, values_str)
            row = row + self.save(sql)
        return row, total

    def whereToStr(self, where):
        """Convert a ``{column: value}`` dict into an AND-joined condition.

        ``None`` values become ``IS NULL``; other values are rendered as SQL
        literals with the same escaping the INSERT helpers use. For untrusted
        input prefer placeholders via ``query(sql, param)``.

        Returns:
            str: the condition without the ``WHERE`` keyword; empty for an
            empty dict.
        """
        conditions = []
        for key, value in where.items():
            column = "`%s`" % str(key).replace("`", "``")
            if value is None:
                conditions.append("%s IS NULL" % column)
            else:
                conditions.append("%s = %s" % (column, self._valueToSql(value)))
        return " AND ".join(conditions)

    def getTableFields(self, table_name):
        """Return the column names of ``table_name`` in the configured db."""
        sql = "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.Columns WHERE TABLE_NAME = %s AND table_schema = %s"
        res = self.select(sql, (table_name, self.config['db']))
        return [item[0] for item in res]

    def __del__(self):
        """Close the connection when the object is collected."""
        conn = getattr(self, "conn", None)
        if conn is None:
            # __init__ never got as far as connecting.
            return
        try:
            conn.close()
        except Exception:
            # Already closed or interpreter shutting down: nothing to do.
            pass

    @staticmethod
    def strFormat(str, format_str_dict=None):
        """Backslash-escape characters that would break a quoted SQL literal.

        Args:
            str: Text to escape (the parameter name shadows the builtin; it
                is kept for compatibility with keyword callers).
            format_str_dict: Extra ``{char: replacement}`` pairs; entries
                override the built-in ones for the same character.

        Returns:
            The escaped text.
        """
        # Backslash must stay first so the backslashes added for the other
        # characters are not escaped a second time.
        format_dict = {
            "\\": "\\\\",
            "'": "\\'",
            '"': '\\"',
            ';': '\\;',
        }
        if format_str_dict:
            format_dict.update(format_str_dict)
        for key, value in format_dict.items():
            str = str.replace(key, value)
        return str

    @classmethod
    def _valueToSql(cls, value):
        """Render one Python value as a SQL literal."""
        if value is None:
            return "NULL"
        if isinstance(value, str):
            return '"%s"' % cls.strFormat(value)
        if isinstance(value, int):
            return '%d' % value
        if isinstance(value, float):
            # repr keeps full precision; "%f" would cut to six decimals.
            return repr(float(value))
        return '"%s"' % cls.strFormat(str(value))

    @classmethod
    def dictValueInsertFormat(cls, data):
        """Format the dict's values as the ``(v1,v2,...)`` part of an INSERT."""
        return "(%s)" % ",".join(cls._valueToSql(data[key]) for key in data)

    @classmethod
    def dictKeyInsertFormat(cls, data):
        """Format the dict's keys as the ``(`c1`,`c2`,...)`` column list."""
        # Double any backtick so a key cannot terminate the quoted identifier.
        names = [str(key).replace("`", "``") for key in data]
        field_str = "(`" + "`,`".join(names) + "`)"
        return field_str

    def __wirteLog(self, sql, status, msg):
        """Append one executed statement to the per-day SQL log file.

        Does nothing unless the ``sql_log`` setting is truthy. The file is
        ``<sql_log_path>/<YYYYMM>/<DD>.log``, dated in the configured
        timezone. Any failure while logging is printed, never raised.

        Args:
            sql: Statement text to record.
            status: 1 for success, anything else for failure.
            msg: Result text or error message.
        """
        if not self.config.get('sql_log'):
            return
        try:
            now = datetime.datetime.now(timezone(self.config['timezone']))
            log_dir = Path(self.config['sql_log_path']) / now.strftime("%Y%m")
            log_dir.mkdir(exist_ok=True, parents=True)
            filename = log_dir / (now.strftime("%d") + ".log")
            label = "[ result ] " if status == 1 else "[ error ] "
            with open(filename, "a+", encoding='utf-8') as sqlfile:
                sqlfile.write("---------------------------------------------------------------------\n")
                sqlfile.write("[ time ] " + now.strftime("%Y-%m-%d %H:%M:%S") + "\n")
                sqlfile.write("[ host ] " + str(self.config['host']) + "\n")
                sqlfile.write("[ dbname ] " + str(self.config['db']) + "\n")
                sqlfile.write("[ sql ] " + sql + "\n")
                sqlfile.write(label + msg + "\n")
        except Exception as e:
            print("日志操作失败:%s" % str(e))

该类库的数据库连接配置文件默认为 dbconfig.py,内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Default connection and logging settings read by the Query class.
config = dict(
    # Database server address
    host='',
    # Login user name
    user='',
    # Server port
    port=3306,
    # Login password
    passwd='',
    # Database (schema) name
    db='',
    # When True, every executed SQL statement is written to the log file
    sql_log=True,
    # Directory the log files are stored under
    sql_log_path="./log/",
    # Timezone used for log timestamps
    timezone='Asia/Shanghai',
)