-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathgenerate.py
More file actions
124 lines (98 loc) · 4.23 KB
/
Copy pathgenerate.py
File metadata and controls
124 lines (98 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# -*- coding: utf-8 -*-
import os
import shutil
import xml.etree.ElementTree as ET
from jinja2 import Environment, FileSystemLoader
# MySQL column-type keyword -> C++ type used for the generated field.
# map_cpp_type resolves a full column type (e.g. "VARCHAR(64)") against these
# keywords.
MYSQL_TO_CPP = {
    "INT": "int",
    "BIGINT": "long long",
    "VARCHAR": "std::string",
    "TEXT": "std::string",
    "FLOAT": "float",
    "DOUBLE": "double",
    "DATETIME": "std::chrono::system_clock::time_point",
}

# C++ field type -> accessor-name suffix handed to the DAO templates
# (presumably combined with a set/get prefix there -- confirm against
# dao.cpp.j2). DATETIME values are carried as "String".
CPP_SET_FUNC = {
    "int": "Int",
    "long long": "Int64",
    "float": "Double",
    "double": "Double",
    "std::string": "String",
    "std::chrono::system_clock::time_point": "String",
}


def map_cpp_type(mysql_type):
    """Return the C++ type for a MySQL column type string.

    Matching is case-insensitive and by keyword containment, so suffixes such
    as ``"(255)"`` or ``" UNSIGNED"`` are tolerated. When several known
    keywords occur in *mysql_type*, the longest one wins; this way ``BIGINT``
    maps to ``long long`` instead of matching its substring ``INT``.
    Unknown, empty or missing (``None``) types fall back to ``std::string``.
    """
    if not mysql_type:
        return "std::string"
    normalized = mysql_type.upper()
    # Longest keyword first so that more specific types beat their substrings.
    for key in sorted(MYSQL_TO_CPP, key=len, reverse=True):
        if key in normalized:
            return MYSQL_TO_CPP[key]
    return "std::string"


def map_cpp_set_func(cpp_type):
    """Return the accessor-name suffix for *cpp_type*, or ``"String"`` if unknown."""
    return CPP_SET_FUNC.get(cpp_type, "String")
def _parse_bool(value, default):
    """Return True iff the XML attribute *value* is "true" (case-insensitive).

    *default* is returned when the attribute is absent (``value is None``).
    """
    if value is None:
        return default
    return value.strip().lower() == "true"


def _split_column_list(text):
    """Split a comma-separated column list into stripped, non-empty names.

    An element with no text (e.g. ``<index/>``) yields an empty list.
    """
    if not text:
        return []
    return [part.strip() for part in text.split(",") if part.strip()]


def _parse_column(col_node):
    """Build the column dict for one ``<column>`` element."""
    col = {
        "name": col_node.get("name"),
        "type": col_node.get("type"),
        "auto_increment": _parse_bool(col_node.get("auto_increment"), False),
        "nullable": _parse_bool(col_node.get("nullable"), True),
        "default": col_node.get("default"),
    }
    col["cpp_type"] = map_cpp_type(col["type"])
    col["cpp_set_func"] = map_cpp_set_func(col["cpp_type"])
    col["is_datetime"] = col["cpp_type"] == "std::chrono::system_clock::time_point"
    return col


def _parse_table(table_node):
    """Build the table dict (names, columns, keys, indexes) for one ``<table>``.

    Raises:
        ValueError: if the table has no ``name`` attribute, or its
            ``<primary_key>`` lists a column that is not declared.
    """
    table_name = table_node.get("name")
    if not table_name:
        raise ValueError("<table> element is missing the required 'name' attribute")

    # Default names derive from the table name; note that str.capitalize()
    # also lowercases the remaining characters ("userInfo" -> "Userinfo").
    class_name = table_node.get("class_name") or table_name.capitalize() + "DAO"
    if class_name.endswith("DAO"):
        # Swap only the trailing "DAO"; str.replace would also rewrite an
        # inner occurrence (e.g. "DAOLogDAO" -> "VOLogVO").
        vo_class_name = class_name[: -len("DAO")] + "VO"
    else:
        vo_class_name = table_name.capitalize() + "VO"

    columns = [_parse_column(node) for node in table_node.findall("column")]
    columns_by_name = {col["name"]: col for col in columns}

    pk_node = table_node.find("primary_key")
    pk_text = pk_node.text if pk_node is not None else None
    # dict.fromkeys de-duplicates while keeping the declared order, which is
    # significant for composite primary keys.
    pk_names = list(dict.fromkeys(_split_column_list(pk_text)))
    unknown = [name for name in pk_names if name not in columns_by_name]
    if unknown:
        raise ValueError(
            f"table {table_name!r}: primary key references unknown column(s) {unknown}"
        )
    pk_set = set(pk_names)
    for col in columns:
        col["is_pk"] = col["name"] in pk_set

    return {
        "name": table_name,
        "class_name": class_name,
        "vo_class_name": vo_class_name,
        "file_name": table_node.get("file_name") or table_name.lower() + "_dao",
        "columns": columns,
        # Same dict objects as in "columns", listed in declared key order.
        "primary_key": [columns_by_name[name] for name in pk_names],
        "unique_keys": [_split_column_list(node.text) for node in table_node.findall("unique_key")],
        "indexes": [_split_column_list(node.text) for node in table_node.findall("index")],
    }


def parse_schema(xml_file):
    """Parse the schema XML file and return one dict per ``<table>`` element.

    Each table dict has the keys ``name``, ``class_name``, ``vo_class_name``,
    ``file_name``, ``columns``, ``primary_key``, ``unique_keys`` and
    ``indexes``.
    - Every column dict carries its XML attributes plus the derived
      ``cpp_type``, ``cpp_set_func``, ``is_datetime`` and ``is_pk``.
    - ``primary_key`` lists the key columns in declared order.
    - ``unique_keys`` and ``indexes`` are lists of column-name lists.

    Args:
        xml_file: Path or file object accepted by ``ElementTree.parse``.

    Raises:
        ValueError: on a table without a name or with a primary key that
            references an undeclared column.
    """
    root = ET.parse(xml_file).getroot()
    return [_parse_table(table_node) for table_node in root.findall("table")]
def render_templates(tables, *, template_dir="templates", output_dir="../Engine/SQL/generated"):
    """Render all generation templates for *tables* into *output_dir*.

    *output_dir* is deleted and recreated first, so it must hold only
    generated files. For each table:
    - its VO is appended to the shared ``vo.h``;
    - its DAO header and source go to ``<file_name>.h`` / ``<file_name>.cpp``;
    - its CREATE TABLE statement is collected.
    The collected statements are joined with blank lines into
    ``create_tables.sql``.

    ``include_header`` is True for a template exactly when it produces the
    first content of its output file. Every generated file therefore gets
    its preamble once, including when several tables share a ``file_name``.

    Args:
        tables: Table dicts as returned by ``parse_schema``.
        template_dir: Directory containing the ``*.j2`` templates.
        output_dir: Directory that receives the generated files.
        Both directories are resolved relative to the current working
        directory; the defaults are the previously hard-coded paths.
    """
    env = Environment(loader=FileSystemLoader(template_dir), trim_blocks=True, lstrip_blocks=True)
    # Load every template before wiping the output, so a missing template
    # fails without destroying the previously generated files.
    vo_template = env.get_template("vo.h.j2")
    dao_h_template = env.get_template("dao.h.j2")
    dao_cpp_template = env.get_template("dao.cpp.j2")
    sql_template = env.get_template("create_table.sql.j2")

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    written = set()  # normalized paths that already received their header

    def append_rendered(path, template, **context):
        # Append *template* output to *path*, requesting the header only on
        # the first write to that particular file.
        key = os.path.normcase(os.path.abspath(path))
        first_write = key not in written
        written.add(key)
        with open(path, "a", encoding="utf-8") as f:
            f.write(template.render(include_header=first_write, **context))

    vo_path = os.path.join(output_dir, "vo.h")
    create_sqls = []
    for table in tables:
        base = os.path.join(output_dir, table["file_name"])
        append_rendered(vo_path, vo_template, table=table)
        append_rendered(base + ".h", dao_h_template, class_name=table["class_name"], table=table)
        append_rendered(base + ".cpp", dao_cpp_template, class_name=table["class_name"], table=table)
        create_sqls.append(sql_template.render(table=table))

    with open(os.path.join(output_dir, "create_tables.sql"), "w", encoding="utf-8") as f:
        f.write("\n\n".join(create_sqls))
if __name__ == "__main__":
    # Script entry point: read schema.xml from the current working directory,
    # then regenerate every output artifact from the parsed tables.
    schema_tables = parse_schema("schema.xml")
    render_templates(schema_tables)