diff --git a/.gitignore b/.gitignore index cfc5d5e..7701a64 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .vscode/* */__pycache__/* +__pycache__ *.pyc __local/* *.db @@ -11,3 +12,4 @@ MercurySQL.egg-info/* runtime runtime/* *.test.py + diff --git a/MercurySQL/__init__.py b/MercurySQL/__init__.py index ecf7ca0..bdfbf62 100644 --- a/MercurySQL/__init__.py +++ b/MercurySQL/__init__.py @@ -17,3 +17,10 @@ from .gensql import DataBase, Table, set_driver from . import drivers + +class SQL: + """Wrap everything together""" + DataBase = DataBase + Table = Table + set_driver = set_driver + drivers = drivers diff --git a/MercurySQL/drivers/mysql.py b/MercurySQL/drivers/mysql_driver.py similarity index 97% rename from MercurySQL/drivers/mysql.py rename to MercurySQL/drivers/mysql_driver.py index 6dd9b6b..cc84eba 100644 --- a/MercurySQL/drivers/mysql.py +++ b/MercurySQL/drivers/mysql_driver.py @@ -190,7 +190,7 @@ def parse(type_: Any) -> str: return res @staticmethod - def connect(db_name: str, host: str, user: str, passwd: str = '', force=False) -> Conn: + def connect(db_name: str, host: str, user: str, password: str = '', force=False) -> Conn: """ Connect to a MySQL database. """ @@ -198,14 +198,14 @@ def connect(db_name: str, host: str, user: str, passwd: str = '', force=False) - return mysql.connector.connect( host=host, user=user, - passwd=passwd, + passwd=password, database=db_name ) else: conn = mysql.connector.connect( host=host, user=user, - passwd=passwd + passwd=password ) conn.backup_cursor = conn.cursor diff --git a/MercurySQL/drivers/sqlite.py b/MercurySQL/drivers/sqlite.py index 548d812..8190d55 100644 --- a/MercurySQL/drivers/sqlite.py +++ b/MercurySQL/drivers/sqlite.py @@ -35,15 +35,15 @@ def get_all_columns(table_name: str) -> str: return f"PRAGMA table_info({table_name});" @staticmethod - def create_table_if_not_exists(table_name: str, column_name: str, column_type: str, primaryKey=False, autoIncrement=False) -> str: + def create_table_if_not_exists(table_name: str, column_name: str, column_type: str, column_default: str | None = None, primaryKey=False, autoIncrement=False) -> str: return f""" - CREATE TABLE IF NOT EXISTS {table_name} ({column_name} {column_type} {'PRIMARY KEY' if primaryKey else ''} {'AUTOINCREMENT' if autoIncrement else ''}) + CREATE TABLE IF NOT EXISTS {table_name} ({column_name} {column_type} {'PRIMARY KEY' if primaryKey else ''} {'AUTOINCREMENT' if autoIncrement else ''} {'DEFAULT ' + column_default if column_default else ''}) """ @staticmethod - def add_column(table_name: str, column_name: str, column_type: str) -> str: + def add_column(table_name: str, column_name: str, column_type: str, column_default: str) -> str: return f""" - ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type} + ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type} {'DEFAULT ' + column_default if column_default else ''} """ @staticmethod @@ -160,6 +160,28 @@ def parse(type_: Any) -> str: # Not Supported raise TypeError(f"Type `{str(type_)}` not supported.") + @staticmethod + def add_punctuation(data: any) -> str: + """ + Convert data to data that can be insert into a sql, + for example if data is str, return 'data' as a string, + if data is a number, return this number + """ + if data is None: + return None + + if isinstance(data, str): + return f"'{data}'" + elif isinstance(data, bool): + return str(data).lower() + elif isinstance(data, (int, float)): + return str(data) + elif isinstance(data, bytes): + return f"x'{data.hex()}'" # Dont know if this is correct + else: + raise TypeError(f"Type `{str(type(data))}` not supported.") + + @staticmethod def connect(db_name: str, **kwargs) -> Driver_SQLite.Conn: return sqlite3.connect(db_name, **kwargs) diff --git a/MercurySQL/errors/driver_errors.py b/MercurySQL/errors/driver_errors.py index 3cf7a25..c9b9c5d 100644 --- a/MercurySQL/errors/driver_errors.py +++ b/MercurySQL/errors/driver_errors.py @@ -7,3 +7,14 @@ def __init__(self, driver: str, dv: str, cv: str): """ self.message = f"Driver `{driver}`(v{dv}) is not supported by MercurySQL(v{cv}). Please update your Driver/MercurySQL to the latest version." super().__init__(self.message) + +class TypeNotMatchError(Exception): + def __init__(self, value: str, value_type: str, column: str, column_type: str) -> None: + """ + :param value: The value of the param. + :param value_type: The type of the param. + :param column: The name of the column. + :param column_type: The type of the column. + """ + self.message = f"The param `{value}` is {value_type} while the type of the column `{column}` is {column_type}." + super().__init__(self.message) \ No newline at end of file diff --git a/MercurySQL/gensql/database.py b/MercurySQL/gensql/database.py index c969e61..0292917 100644 --- a/MercurySQL/gensql/database.py +++ b/MercurySQL/gensql/database.py @@ -22,6 +22,9 @@ from .table import Table +from ..drivers import sqlite + + # ========== Class Decorations ========== class DataBase: pass diff --git a/MercurySQL/gensql/table.py b/MercurySQL/gensql/table.py index 3b5d1e8..9a8ab75 100644 --- a/MercurySQL/gensql/table.py +++ b/MercurySQL/gensql/table.py @@ -207,7 +207,7 @@ def select(self, exp: Exp = None, selection: str = "*") -> QueryResult: return QueryResult(self, exp, selection) def newColumn( - self, name: str, type_: Any, force=False, primaryKey=False, autoIncrement=False + self, name: str, type_: Any, default = None, force=False, primaryKey=False, autoIncrement=False ) -> None: """ Add a new column to the table. @@ -242,8 +242,12 @@ def newColumn( raise DuplicateError(f"Column `{name}` already exists.") else: return + + if default and type(default) != type_: + raise TypeNotMatchError(default, type(default), name, type_) type_ = self.driver.TypeParser.parse(type_) + default = self.driver.TypeParser.add_punctuation(default) if self.isEmpty: # create it first @@ -251,13 +255,14 @@ def newColumn( self.table_name, name, type_, + default, primaryKey=primaryKey, autoIncrement=autoIncrement, ) self.db.do(cmd) self.isEmpty = False else: - cmd = self.driver.APIs.gensql.add_column(self.table_name, name, type_) + cmd = self.driver.APIs.gensql.add_column(self.table_name, name, type_, default) self.db.do(cmd) if primaryKey: @@ -267,7 +272,7 @@ def newColumn( self.columnsType[name] = type_ def struct( - self, columns: dict, skipError=True, primaryKey: str = None, autoIncrement=False, force=True + self, columns: dict, skipError=True, primaryKey: str = None, autoIncrement=False, force=True, rebuild=False ) -> None: """ Set the structure of the table. @@ -278,6 +283,8 @@ def struct( :type skipError: bool :param primaryKey: The primary key of the table. :type primaryKey: str + :param rebuild: Whether to rebuild the table. + :type rebuild: bool Example Usage: @@ -295,26 +302,40 @@ def struct( skipError = skipError and force for name, type_ in columns.items(): + default_value = None + + + # 支持 type_ 为 [str, "默认值"] 的写法 + if isinstance(type_, list): + default_value = type_[1] + type_ = type_[0] + type_origin = type_ type_ = self.driver.TypeParser.parse(type_) isPrimaryKey = name == primaryKey if name in self.columns: - if type_.lower() != self.columnsType[name].lower(): + if rebuild: + print(f"Testing function, still constructing...") + # TODO: 这个Rebuild理论上会重新创建一个表,但是对应的底层代码还没写完 + # self.delColumn(name) + elif type_.lower() != self.columnsType[name].lower(): raise ConfilictError( f"Column `{name}` with different types (`{self.columnsType[name]}`) already exists. While trying to add column `{name}` with type `{type_}`." ) elif not skipError: raise DuplicateError( - f"Column `{name}` already exists. You can use `force=True` to avoid this error." + f"Column `{name}` already exists. You can use `skipError=True` to avoid this error." ) - else: - self.newColumn( - name, - type_origin, - primaryKey=isPrimaryKey, - autoIncrement=autoIncrement, - ) + + # Raise 错误后不会执行后续代码,所以这里可以不用 Else + self.newColumn( + name, + type_origin, + default=default_value, + primaryKey=isPrimaryKey, + autoIncrement=autoIncrement, + ) def delColumn(self, name: str) -> None: if name not in self.columns: @@ -326,7 +347,8 @@ def delColumn(self, name: str) -> None: else: # delete the column self.columns.remove(name) - cmd = self.driver.APIs.gensql.drop_column(name) + # 这一行有Bug,修复了 + cmd = self.driver.APIs.gensql.drop_column(table_name=self.table_name,column_name=name) self.db.do(cmd) def setPrimaryKey(self, keyname: str, keytype: str) -> None: @@ -365,8 +387,16 @@ def insert(self, __auto=False, **kwargs) -> None: table.insert(id=1, name='Bernie', age=15, __auto=True) """ - # get keys and clean them - keys = list(kwargs.keys()) + + # TODO: 为了兼容v1.0中的json格式传参的问题,v1.0需要修复 + # if __auto is a dict + if isinstance(__auto, dict): + keys = list(__auto.keys()) + kwargs = __auto + else: + # get keys and clean them + keys = list(kwargs.keys()) + if "__auto" in keys: __auto = kwargs["__auto"] keys.remove("__auto") diff --git a/requirements.txt b/requirements.txt index ef634fa..023d895 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,6 @@ sphinx_rtd_theme # for release wheel + +# MySQL +mysql.connector