瀏覽代碼

bug fix: support for netezza

Steve Nyemba 4 年之前
父節點
當前提交
7c2e945996
共有 2 個文件被更改,包括 44 次插入15 次删除
  1. 2 2
      setup.py
  2. 42 13
      transport/sql.py

+ 2 - 2
setup.py

@@ -8,12 +8,12 @@ def read(fname):
     return open(os.path.join(os.path.dirname(__file__), fname)).read() 
 args    = {
     "name":"data-transport",
-    "version":"1.3.8.4",
+    "version":"1.3.8.6.1",
     "author":"The Phi Technology LLC","author_email":"info@the-phi.com",
     "license":"MIT",
     "packages":["transport"]}
 args["keywords"]=['mongodb','couchdb','rabbitmq','file','read','write','s3','sqlite']
-args["install_requires"] = ['pymongo','numpy','cloudant','pika','boto3','boto','pyarrow','google-cloud-bigquery','google-cloud-bigquery-storage','flask-session','smart_open','botocore','psycopg2-binary','mysql-connector-python']
+args["install_requires"] = ['pymongo','numpy','cloudant','pika','nzpy','boto3','boto','pyarrow','google-cloud-bigquery','google-cloud-bigquery-storage','flask-session','smart_open','botocore','psycopg2-binary','mysql-connector-python']
 args["url"] =   "https://healthcareio.the-phi.com/git/code/transport.git"
 
 if sys.version_info[0] == 2 :

+ 42 - 13
transport/sql.py

@@ -22,23 +22,32 @@ from google.cloud import bigquery as bq
 from multiprocessing import Lock
 import pandas as pd
 import numpy as np
+import nzpy as nz   #--- netezza drivers
 import copy
 
 
 class SQLRW :
-    PROVIDERS = {"postgresql":"5432","redshift":"5432","mysql":"3306","mariadb":"3306"}
-    DRIVERS  = {"postgresql":pg,"redshift":pg,"mysql":my,"mariadb":my}
+    PROVIDERS = {"postgresql":"5432","redshift":"5432","mysql":"3306","mariadb":"3306","netezza":5480}
+    DRIVERS  = {"postgresql":pg,"redshift":pg,"mysql":my,"mariadb":my,"netezza":nz}
+    REFERENCE = {
+        "netezza":{"port":5480,"handler":nz,"dtype":"VARCHAR(512)"},
+        "postgresql":{"port":5432,"handler":pg,"dtype":"VARCHAR"},
+        "redshift":{"port":5432,"handler":pg,"dtype":"VARCHAR"},
+        "mysql":{"port":3360,"handler":my,"dtype":"VARCHAR(256)"},
+        "mariadb":{"port":3360,"handler":my,"dtype":"VARCHAR(256)"},
+        }
     def __init__(self,**_args):
         
         
         _info = {}
-        _info['dbname']     = _args['db']
+        _info['dbname'] = _args['db'] if 'db' in _args else _args['database']
         self.table      = _args['table']
         self.fields     = _args['fields'] if 'fields' in _args else []
-        
+        _provider       = _args['provider']
         if 'host' in _args :
             _info['host'] = 'localhost' if 'host' not in _args else _args['host']
-            _info['port'] = SQLWriter.PROVIDERS[_args['provider']] if 'port' not in _args else _args['port']
+            # _info['port'] = SQLWriter.PROVIDERS[_args['provider']] if 'port' not in _args else _args['port']
+            _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
         
         if 'username' in _args or 'user' in _args:
             key = 'username' if 'username' in _args else 'user'
@@ -46,7 +55,13 @@ class SQLRW :
             _info['password'] = _args['password']
         #
         # We need to load the drivers here to see what we are dealing with ...
-        _handler = SQLWriter.DRIVERS[_args['provider']]
+        # _handler = SQLWriter.DRIVERS[_args['provider']]
+        _handler = SQLWriter.REFERENCE[_provider]['handler']
+        self._dtype = SQLWriter.REFERENCE[_provider]['dtype'] if 'dtype' not in _args else _args['dtype']
+        if _handler == nz :
+            _info['database'] = _info['dbname']
+            _info['securityLevel'] = 0
+            del _info['dbname']
         self.conn = _handler.connect(**_info)
     
     def isready(self):
@@ -118,11 +133,13 @@ class SQLWriter(SQLRW,Writer):
 
     def make(self,fields):
         self.fields = fields
-        sql = " ".join(["CREATE TABLE",self.table," (", ",".join(fields),")"])
+        
+        sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
         cursor = self.conn.cursor()
         try:
             cursor.execute(sql)
         except Exception as e :
+            print (e)
             pass
         finally:
             cursor.close()
@@ -136,12 +153,14 @@ class SQLWriter(SQLRW,Writer):
             _fields = info.keys() if type(info) == dict else info[0].keys()
             _fields = list (_fields)
             self.init(_fields)
-
+        #
+        # @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
+        #
         if type(info) != list :
             info = [info]        
         cursor = self.conn.cursor()
         try:
-            _sql = "INSERT INTO :table (:fields) values (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
+            _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
             if self._inspect :
                 for _row in info :
                     fields = list(_row.keys())
@@ -161,15 +180,19 @@ class SQLWriter(SQLRW,Writer):
                 pass
             else:
                 _fields = ",".join(self.fields)
-                _sql = _sql.replace(":fields",_fields)
-                _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
+                # _sql = _sql.replace(":fields",_fields)
+                # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
+                _sql = _sql.replace("(:fields)","")
+                values = ", ".join('?'*len(self.fields))
+                _sql = _sql.replace(":values",values)
                 
                 # for row in info :
                 #     values = ["'".join(["",value,""]) if not str(value).isnumeric() else value for value in row.values()]
                 cursor.executemany(_sql,info)   
+            
             # self.conn.commit()
         except Exception as e:
-            print (e) 
+            pass
         finally:
             self.conn.commit()
             cursor.close()
@@ -265,7 +288,13 @@ class BQWriter(BigQuery,Writer):
             _df.to_gbq(**self.mode) #if_exists='append',destination_table=partial,credentials=credentials,chunksize=90000)	
             
         pass
-# import transport    
+import transport    
+try:
+    _args = {'type':'sql.SQLWriter','args':{'provider':'netezza','host':'ori-netezza.vumc.org','table':'IBM_CCS_DX','username':'nyembsl1','password':'Innovat10n','database':'MALIN_OMOP_RD'}}
+    df = pd
+    reader = SQLReader(**_args['args'])
+except Exception as error :
+    print (error)
 # reader = transport.factory.instance(type="sql.BQReader",args={"service_key":"/home/steve/dev/google-cloud-sdk/accounts/curation-prod.json"})
 # _df = reader.read(sql="select  * from `2019q1r4_combined.person` limit 10")
 # writer = transport.factory.instance(type="sql.BQWriter",args={"service_key":"/home/steve/dev/google-cloud-sdk/accounts/curation-prod.json"})