Browse Source

Bug fix: sqlalchemy facilities added

Steve Nyemba 3 years ago
parent
commit
14a551e57b
2 changed files with 130 additions and 44 deletions
  1. 35 4
      transport/__init__.py
  2. 95 40
      transport/sql.py

+ 35 - 4
transport/__init__.py

@@ -26,7 +26,7 @@ import numpy 	as np
 import json
 import json
 import importlib 
 import importlib 
 import sys 
 import sys 
-
+import sqlalchemy
 if sys.version_info[0] > 2 : 
 if sys.version_info[0] > 2 : 
     from transport.common import Reader, Writer #, factory
     from transport.common import Reader, Writer #, factory
     from transport import disk
     from transport import disk
@@ -59,8 +59,8 @@ class factory :
         "postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
         "postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
         "redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
         "redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
         "bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}},
         "bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}},
-        "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
-        "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
+        "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
+        "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
 		"mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}},		
 		"mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}},		
 		"couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}},		
 		"couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}},		
         "netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}}
         "netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}}
@@ -137,7 +137,38 @@ def instance(**_args):
 			pointer = factory.PROVIDERS[provider]['class'][_id] 
 			pointer = factory.PROVIDERS[provider]['class'][_id] 
 		else:
 		else:
 			pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter
 			pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter
-		
+		#
+		# Let us try to establish an sqlalchemy wrapper
+		try:
+			host = ''
+			if provider not in ['bigquery','mongodb','couchdb','sqlite'] :
+				#
+				# In these cases we are assuming RDBMS and thus would exclude NoSQL and BigQuery
+				username = args['username'] if 'username' in args else ''
+				password = args['password'] if 'password' in args else ''
+				if username == '' :
+					account = ''
+				else:
+					account = username + ':'+password+'@'
+				host = args['host'] 
+				if 'port' in args :
+					host = host+":"+str(args['port'])
+				
+				database =  args['database']	
+			elif provider == 'sqlite':
+				account = ''
+				host = ''
+				database = args['path'] if 'path' in args else args['database']
+			if provider not in ['mongodb','couchdb','bigquery'] :
+				uri = ''.join([provider,"://",account,host,'/',database])
+				
+				e = sqlalchemy.create_engine (uri)
+				args['sqlalchemy'] = e 
+			#
+			# @TODO: Include handling of bigquery with SQLAlchemy
+		except Exception as e:
+			print (e)
+
 		return pointer(**args)
 		return pointer(**args)
 
 
 	return None
 	return None

+ 95 - 40
transport/sql.py

@@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
 import psycopg2 as pg
 import psycopg2 as pg
 import mysql.connector as my
 import mysql.connector as my
 import sys
 import sys
+
+import sqlalchemy
 if sys.version_info[0] > 2 : 
 if sys.version_info[0] > 2 : 
     from transport.common import Reader, Writer #, factory
     from transport.common import Reader, Writer #, factory
 else:
 else:
@@ -44,7 +46,8 @@ class SQLRW :
         _info['dbname'] = _args['db'] if 'db' in _args else _args['database']
         _info['dbname'] = _args['db'] if 'db' in _args else _args['database']
         self.table      = _args['table'] if 'table' in _args else None
         self.table      = _args['table'] if 'table' in _args else None
         self.fields     = _args['fields'] if 'fields' in _args else []
         self.fields     = _args['fields'] if 'fields' in _args else []
-        # _provider       = _args['provider']
+        
+        self._provider       = _args['provider'] if 'provider' in _args else None
         # _info['host'] = 'localhost' if 'host' not in _args else _args['host']
         # _info['host'] = 'localhost' if 'host' not in _args else _args['host']
         # _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
         # _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
 
 
@@ -59,7 +62,7 @@ class SQLRW :
         if 'username' in _args or 'user' in _args:
         if 'username' in _args or 'user' in _args:
             key = 'username' if 'username' in _args else 'user'
             key = 'username' if 'username' in _args else 'user'
             _info['user'] = _args[key]
             _info['user'] = _args[key]
-            _info['password'] = _args['password']
+            _info['password'] = _args['password'] if 'password' in _args else ''
         #
         #
         # We need to load the drivers here to see what we are dealing with ...
         # We need to load the drivers here to see what we are dealing with ...
         
         
@@ -74,17 +77,29 @@ class SQLRW :
             _info['database'] = _info['dbname']
             _info['database'] = _info['dbname']
             _info['securityLevel'] = 0
             _info['securityLevel'] = 0
             del _info['dbname']
             del _info['dbname']
+        if _handler == my :
+            _info['database'] = _info['dbname']
+            del _info['dbname']
+        
         self.conn = _handler.connect(**_info)
         self.conn = _handler.connect(**_info)
+        self._engine = _args['sqlalchemy']  if 'sqlalchemy' in _args else None
     def has(self,**_args):
     def has(self,**_args):
         found = False
         found = False
         try:
         try:
             table = _args['table']
             table = _args['table']
             sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
             sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
-            found = pd.read_sql(sql,self.conn).shape[0] 
+            if self._engine :
+                _conn = self._engine.connect()
+            else:
+                _conn = self.conn
+            found = pd.read_sql(sql,_conn).shape[0] 
             found = True
             found = True
 
 
         except Exception as e:
         except Exception as e:
             pass
             pass
+        finally:
+            if self._engine :
+                _conn.close()
         return found
         return found
     def isready(self):
     def isready(self):
         _sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table)
         _sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table)
@@ -104,7 +119,8 @@ class SQLRW :
         try:
         try:
             if "select" in _sql.lower() :
             if "select" in _sql.lower() :
                 cursor.close()
                 cursor.close()
-                return pd.read_sql(_sql,self.conn)
+                _conn = self._engine.connect() if self._engine else self.conn
+                return pd.read_sql(_sql,_conn)
             else:
             else:
                 # Executing a command i.e no expected return values ...
                 # Executing a command i.e no expected return values ...
                 cursor.execute(_sql)
                 cursor.execute(_sql)
@@ -122,7 +138,8 @@ class SQLRW :
             pass
             pass
 class SQLReader(SQLRW,Reader) :
 class SQLReader(SQLRW,Reader) :
     def __init__(self,**_args):
     def __init__(self,**_args):
-        super().__init__(**_args)     
+        super().__init__(**_args) 
+        
     def read(self,**_args):
     def read(self,**_args):
         if 'sql' in _args :            
         if 'sql' in _args :            
             _sql = (_args['sql'])
             _sql = (_args['sql'])
@@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer):
         # NOTE: Proper data type should be set on the target system if their source is unclear.
         # NOTE: Proper data type should be set on the target system if their source is unclear.
         self._inspect = False if 'inspect' not in _args else _args['inspect']
         self._inspect = False if 'inspect' not in _args else _args['inspect']
         self._cast = False if 'cast' not in _args else _args['cast']
         self._cast = False if 'cast' not in _args else _args['cast']
+        
     def init(self,fields=None):
     def init(self,fields=None):
         if not fields :
         if not fields :
             try:                
             try:                
-                self.fields = pd.read_sql("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
+                self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
             finally:
             finally:
                 pass
                 pass
         else:
         else:
             self.fields = fields;
             self.fields = fields;
 
 
-    def make(self,fields):
-        self.fields = fields
-        
-        sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+    def make(self,**_args):
+
+        if 'fields' in _args :
+            fields = _args['fields']            
+            sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+        else:
+            schema = _args['schema']
+            N = len(schema)
+            _map = _args['map'] if 'map' in _args else {}
+            sql = [] # ["CREATE TABLE ",_args['table'],"("]
+            for _item in schema :
+                _type = _item['type']
+                if _type in _map :
+                    _type = _map[_type]
+                sql = sql + [" " .join([_item['name'], ' ',_type])]
+            sql = ",".join(sql)
+            sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"]
+            sql = " ".join(sql)
+            # sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"])
         cursor = self.conn.cursor()
         cursor = self.conn.cursor()
         try:
         try:
+            
             cursor.execute(sql)
             cursor.execute(sql)
         except Exception as e :
         except Exception as e :
             print (e)
             print (e)
+            print (sql)
             pass
             pass
         finally:
         finally:
-            cursor.close()
+            # cursor.close()
+            self.conn.commit()
+            pass
     def write(self,info):
     def write(self,info):
         """
         """
         :param info writes a list of data to a given set of fields
         :param info writes a list of data to a given set of fields
@@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer):
             elif type(info) == dict :
             elif type(info) == dict :
                 _fields = info.keys()
                 _fields = info.keys()
             elif type(info) == pd.DataFrame :
             elif type(info) == pd.DataFrame :
-                _fields = info.columns
+                _fields = info.columns.tolist()
 
 
             # _fields = info.keys() if type(info) == dict else info[0].keys()
             # _fields = info.keys() if type(info) == dict else info[0].keys()
             _fields = list (_fields)
             _fields = list (_fields)
@@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer):
         #
         #
         # @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
         # @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
         #
         #
-        if type(info) != list :
-            #
-            # We are assuming 2 cases i.e dict or pd.DataFrame
-            info = [info]  if type(info) == dict else info.values.tolist()       
+        # if type(info) != list :
+        #     #
+        #     # We are assuming 2 cases i.e dict or pd.DataFrame
+        #     info = [info]  if type(info) == dict else info.values.tolist()       
         cursor = self.conn.cursor()
         cursor = self.conn.cursor()
         try:
         try:
+            
             _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
             _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
             if self._inspect :
             if self._inspect :
                 for _row in info :
                 for _row in info :
@@ -223,34 +261,49 @@ class SQLWriter(SQLRW,Writer):
 
 
                 pass
                 pass
             else:
             else:
-                _fields = ",".join(self.fields)
+                
                 # _sql = _sql.replace(":fields",_fields)
                 # _sql = _sql.replace(":fields",_fields)
                 # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
                 # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
                 # _sql = _sql.replace("(:fields)","")
                 # _sql = _sql.replace("(:fields)","")
-                _sql = _sql.replace(":fields",_fields)
-                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                _sql = _sql.replace(":values",values)
-                if type(info) == pd.DataFrame :
-                    _info = info[self.fields].values.tolist()
-                elif  type(info) == dict :
-                    _info = info.values()
-                else:
-                    # _info = []
+                
+                # _sql = _sql.replace(":values",values)
+                # if type(info) == pd.DataFrame :
+                #     _info = info[self.fields].values.tolist()
+                    
+                # elif  type(info) == dict :
+                #     _info = info.values()
+                # else:
+                #     # _info = []
 
 
-                    _info = pd.DataFrame(info)[self.fields].values.tolist()
-                    # for row in info :
-                        
-                    #     if type(row) == dict :
-                    #         _info.append( list(row.values()))
-                cursor.executemany(_sql,_info)   
+                #     _info = pd.DataFrame(info)[self.fields].values.tolist()
+                    # _info = pd.DataFrame(info).to_dict(orient='records')
+                if type(info) == list :
+                    _info = pd.DataFrame(info)
+                elif type(info) == dict :
+                    _info = pd.DataFrame([info])
+                else:
+                    _info = pd.DataFrame(info)
+            
+                
+                if self._engine :
+                    # pd.to_sql(_info,self._engine)
+                    _info.to_sql(self.table,self._engine,if_exists='append',index=False)
+                else:
+                    _fields = ",".join(self.fields)
+                    _sql = _sql.replace(":fields",_fields)
+                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                    _sql = _sql.replace(":values",values)
+                    
+                    cursor.executemany(_sql,_info.values.tolist())  
+                # cursor.commit() 
             
             
             # self.conn.commit()
             # self.conn.commit()
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
             pass
             pass
         finally:
         finally:
-            self.conn.commit()
-            cursor.close()
+            self.conn.commit()            
+            # cursor.close()
             pass
             pass
     def close(self):
     def close(self):
         try:
         try:
@@ -265,6 +318,7 @@ class BigQuery:
         self.path = path
         self.path = path
         self.dtypes = _args['dtypes'] if 'dtypes' in _args else None
         self.dtypes = _args['dtypes'] if 'dtypes' in _args else None
         self.table = _args['table'] if 'table' in _args else None
         self.table = _args['table'] if 'table' in _args else None
+        self.client = bq.Client.from_service_account_json(self.path)
     def meta(self,**_args):
     def meta(self,**_args):
         """
         """
         This function returns meta data for a given table or query with dataset/table properly formatted
         This function returns meta data for a given table or query with dataset/table properly formatted
@@ -272,16 +326,16 @@ class BigQuery:
         :param sql      sql query to be pulled,
         :param sql      sql query to be pulled,
         """
         """
         table = _args['table']
         table = _args['table']
-        client = bq.Client.from_service_account_json(self.path)
-        ref     = client.dataset(self.dataset).table(table)
-        return client.get_table(ref).schema
+        
+        ref     = self.client.dataset(self.dataset).table(table)
+        return self.client.get_table(ref).schema
     def has(self,**_args):
     def has(self,**_args):
         found = False
         found = False
         try:
         try:
             found = self.meta(**_args) is not None
             found = self.meta(**_args) is not None
         except Exception as e:
         except Exception as e:
             pass
             pass
-            return found
+        return found
 class BQReader(BigQuery,Reader) :
 class BQReader(BigQuery,Reader) :
     def __init__(self,**_args):
     def __init__(self,**_args):
         
         
@@ -304,8 +358,9 @@ class BQReader(BigQuery,Reader) :
         if (':dataset' in SQL or ':DATASET' in SQL)  and self.dataset:
         if (':dataset' in SQL or ':DATASET' in SQL)  and self.dataset:
             SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset)
             SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset)
         _info = {'credentials':self.credentials,'dialect':'standard'}       
         _info = {'credentials':self.credentials,'dialect':'standard'}       
-        return pd.read_gbq(SQL,**_info) if SQL else None    
-        # return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None
+        return pd.read_gbq(SQL,**_info) if SQL else None  
+        # return self.client.query(SQL).to_dataframe() if SQL else None
+        
 
 
 class BQWriter(BigQuery,Writer):
 class BQWriter(BigQuery,Writer):
     lock = Lock()
     lock = Lock()