Przeglądaj źródła

Bug fix: sqlalchemy facilities added

Steve Nyemba 3 lat temu
rodzic
commit
14a551e57b
2 zmienionych plików z 130 dodań i 44 usunięć
  1. 35 4
      transport/__init__.py
  2. 95 40
      transport/sql.py

+ 35 - 4
transport/__init__.py

@@ -26,7 +26,7 @@ import numpy 	as np
 import json
 import importlib 
 import sys 
-
+import sqlalchemy
 if sys.version_info[0] > 2 : 
     from transport.common import Reader, Writer #, factory
     from transport import disk
@@ -59,8 +59,8 @@ class factory :
         "postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
         "redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
         "bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}},
-        "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
-        "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
+        "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
+        "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
 		"mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}},		
 		"couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}},		
         "netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}}
@@ -137,7 +137,38 @@ def instance(**_args):
 			pointer = factory.PROVIDERS[provider]['class'][_id] 
 		else:
 			pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter
-		
+		#
+		# Let us try to establish an sqlalchemy wrapper
+		try:
+			host = ''
+			if provider not in ['bigquery','mongodb','couchdb','sqlite'] :
+				#
+				# In these cases we are assuming RDBMS and thus would exclude NoSQL and BigQuery
+				username = args['username'] if 'username' in args else ''
+				password = args['password'] if 'password' in args else ''
+				if username == '' :
+					account = ''
+				else:
+					account = username + ':'+password+'@'
+				host = args['host'] 
+				if 'port' in args :
+					host = host+":"+str(args['port'])
+				
+				database =  args['database']	
+			elif provider == 'sqlite':
+				account = ''
+				host = ''
+				database = args['path'] if 'path' in args else args['database']
+			if provider not in ['mongodb','couchdb','bigquery'] :
+				uri = ''.join([provider,"://",account,host,'/',database])
+				
+				e = sqlalchemy.create_engine (uri)
+				args['sqlalchemy'] = e 
+			#
+			# @TODO: Include handling of bigquery with SQLAlchemy
+		except Exception as e:
+			print (e)
+
 		return pointer(**args)
 
 	return None

+ 95 - 40
transport/sql.py

@@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
 import psycopg2 as pg
 import mysql.connector as my
 import sys
+
+import sqlalchemy
 if sys.version_info[0] > 2 : 
     from transport.common import Reader, Writer #, factory
 else:
@@ -44,7 +46,8 @@ class SQLRW :
         _info['dbname'] = _args['db'] if 'db' in _args else _args['database']
         self.table      = _args['table'] if 'table' in _args else None
         self.fields     = _args['fields'] if 'fields' in _args else []
-        # _provider       = _args['provider']
+        
+        self._provider       = _args['provider'] if 'provider' in _args else None
         # _info['host'] = 'localhost' if 'host' not in _args else _args['host']
         # _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
 
@@ -59,7 +62,7 @@ class SQLRW :
         if 'username' in _args or 'user' in _args:
             key = 'username' if 'username' in _args else 'user'
             _info['user'] = _args[key]
-            _info['password'] = _args['password']
+            _info['password'] = _args['password'] if 'password' in _args else ''
         #
         # We need to load the drivers here to see what we are dealing with ...
         
@@ -74,17 +77,29 @@ class SQLRW :
             _info['database'] = _info['dbname']
             _info['securityLevel'] = 0
             del _info['dbname']
+        if _handler == my :
+            _info['database'] = _info['dbname']
+            del _info['dbname']
+        
         self.conn = _handler.connect(**_info)
+        self._engine = _args['sqlalchemy']  if 'sqlalchemy' in _args else None
     def has(self,**_args):
         found = False
         try:
             table = _args['table']
             sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
-            found = pd.read_sql(sql,self.conn).shape[0] 
+            if self._engine :
+                _conn = self._engine.connect()
+            else:
+                _conn = self.conn
+            found = pd.read_sql(sql,_conn).shape[0] 
             found = True
 
         except Exception as e:
             pass
+        finally:
+            if self._engine :
+                _conn.close()
         return found
     def isready(self):
         _sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table)
@@ -104,7 +119,8 @@ class SQLRW :
         try:
             if "select" in _sql.lower() :
                 cursor.close()
-                return pd.read_sql(_sql,self.conn)
+                _conn = self._engine.connect() if self._engine else self.conn
+                return pd.read_sql(_sql,_conn)
             else:
                 # Executing a command i.e no expected return values ...
                 cursor.execute(_sql)
@@ -122,7 +138,8 @@ class SQLRW :
             pass
 class SQLReader(SQLRW,Reader) :
     def __init__(self,**_args):
-        super().__init__(**_args)     
+        super().__init__(**_args) 
+        
     def read(self,**_args):
         if 'sql' in _args :            
             _sql = (_args['sql'])
@@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer):
         # NOTE: Proper data type should be set on the target system if their source is unclear.
         self._inspect = False if 'inspect' not in _args else _args['inspect']
         self._cast = False if 'cast' not in _args else _args['cast']
+        
     def init(self,fields=None):
         if not fields :
             try:                
-                self.fields = pd.read_sql("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
+                self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
             finally:
                 pass
         else:
             self.fields = fields;
 
-    def make(self,fields):
-        self.fields = fields
-        
-        sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+    def make(self,**_args):
+
+        if 'fields' in _args :
+            fields = _args['fields']            
+            sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+        else:
+            schema = _args['schema']
+            N = len(schema)
+            _map = _args['map'] if 'map' in _args else {}
+            sql = [] # ["CREATE TABLE ",_args['table'],"("]
+            for _item in schema :
+                _type = _item['type']
+                if _type in _map :
+                    _type = _map[_type]
+                sql = sql + [" " .join([_item['name'], ' ',_type])]
+            sql = ",".join(sql)
+            sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"]
+            sql = " ".join(sql)
+            # sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"])
         cursor = self.conn.cursor()
         try:
+            
             cursor.execute(sql)
         except Exception as e :
             print (e)
+            print (sql)
             pass
         finally:
-            cursor.close()
+            # cursor.close()
+            self.conn.commit()
+            pass
     def write(self,info):
         """
         :param info writes a list of data to a given set of fields
@@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer):
             elif type(info) == dict :
                 _fields = info.keys()
             elif type(info) == pd.DataFrame :
-                _fields = info.columns
+                _fields = info.columns.tolist()
 
             # _fields = info.keys() if type(info) == dict else info[0].keys()
             _fields = list (_fields)
@@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer):
         #
         # @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
         #
-        if type(info) != list :
-            #
-            # We are assuming 2 cases i.e dict or pd.DataFrame
-            info = [info]  if type(info) == dict else info.values.tolist()       
+        # if type(info) != list :
+        #     #
+        #     # We are assuming 2 cases i.e dict or pd.DataFrame
+        #     info = [info]  if type(info) == dict else info.values.tolist()       
         cursor = self.conn.cursor()
         try:
+            
             _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
             if self._inspect :
                 for _row in info :
@@ -223,34 +261,49 @@ class SQLWriter(SQLRW,Writer):
 
                 pass
             else:
-                _fields = ",".join(self.fields)
+                
                 # _sql = _sql.replace(":fields",_fields)
                 # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
                 # _sql = _sql.replace("(:fields)","")
-                _sql = _sql.replace(":fields",_fields)
-                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                _sql = _sql.replace(":values",values)
-                if type(info) == pd.DataFrame :
-                    _info = info[self.fields].values.tolist()
-                elif  type(info) == dict :
-                    _info = info.values()
-                else:
-                    # _info = []
+                
+                # _sql = _sql.replace(":values",values)
+                # if type(info) == pd.DataFrame :
+                #     _info = info[self.fields].values.tolist()
+                    
+                # elif  type(info) == dict :
+                #     _info = info.values()
+                # else:
+                #     # _info = []
 
-                    _info = pd.DataFrame(info)[self.fields].values.tolist()
-                    # for row in info :
-                        
-                    #     if type(row) == dict :
-                    #         _info.append( list(row.values()))
-                cursor.executemany(_sql,_info)   
+                #     _info = pd.DataFrame(info)[self.fields].values.tolist()
+                    # _info = pd.DataFrame(info).to_dict(orient='records')
+                if type(info) == list :
+                    _info = pd.DataFrame(info)
+                elif type(info) == dict :
+                    _info = pd.DataFrame([info])
+                else:
+                    _info = pd.DataFrame(info)
+            
+                
+                if self._engine :
+                    # pd.to_sql(_info,self._engine)
+                    _info.to_sql(self.table,self._engine,if_exists='append',index=False)
+                else:
+                    _fields = ",".join(self.fields)
+                    _sql = _sql.replace(":fields",_fields)
+                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                    _sql = _sql.replace(":values",values)
+                    
+                    cursor.executemany(_sql,_info.values.tolist())  
+                # cursor.commit() 
             
             # self.conn.commit()
         except Exception as e:
             print(e)
             pass
         finally:
-            self.conn.commit()
-            cursor.close()
+            self.conn.commit()            
+            # cursor.close()
             pass
     def close(self):
         try:
@@ -265,6 +318,7 @@ class BigQuery:
         self.path = path
         self.dtypes = _args['dtypes'] if 'dtypes' in _args else None
         self.table = _args['table'] if 'table' in _args else None
+        self.client = bq.Client.from_service_account_json(self.path)
     def meta(self,**_args):
         """
         This function returns meta data for a given table or query with dataset/table properly formatted
@@ -272,16 +326,16 @@ class BigQuery:
         :param sql      sql query to be pulled,
         """
         table = _args['table']
-        client = bq.Client.from_service_account_json(self.path)
-        ref     = client.dataset(self.dataset).table(table)
-        return client.get_table(ref).schema
+        
+        ref     = self.client.dataset(self.dataset).table(table)
+        return self.client.get_table(ref).schema
     def has(self,**_args):
         found = False
         try:
             found = self.meta(**_args) is not None
         except Exception as e:
             pass
-            return found
+        return found
 class BQReader(BigQuery,Reader) :
     def __init__(self,**_args):
         
@@ -304,8 +358,9 @@ class BQReader(BigQuery,Reader) :
         if (':dataset' in SQL or ':DATASET' in SQL)  and self.dataset:
             SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset)
         _info = {'credentials':self.credentials,'dialect':'standard'}       
-        return pd.read_gbq(SQL,**_info) if SQL else None    
-        # return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None
+        return pd.read_gbq(SQL,**_info) if SQL else None  
+        # return self.client.query(SQL).to_dataframe() if SQL else None
+        
 
 class BQWriter(BigQuery,Writer):
     lock = Lock()