Procházet zdrojové kódy

bug fix: write problems, using drivers

Steve Nyemba před 3 roky
rodič
revize
fd9442c298
2 změnil soubory, kde provedl 32 přidání a 19 odebrání
  1. 1 1
      transport/common.py
  2. 31 18
      transport/sql.py

+ 1 - 1
transport/common.py

@@ -41,7 +41,7 @@ class Reader (IO):
 	"""
 	def __init__(self):
 		pass
-	def meta(self):
+	def meta(self,**_args):
 		"""
 		This function is intended to return meta-data associated with what has just been read
 		@return object of meta data information associated with the content of the store

+ 31 - 18
transport/sql.py

@@ -46,6 +46,7 @@ class SQLRW :
         _info['dbname'] = _args['db'] if 'db' in _args else _args['database']
         self.table      = _args['table'] if 'table' in _args else None
         self.fields     = _args['fields'] if 'fields' in _args else []
+        self.schema = _args['schema'] if 'schema' in _args else ''
         
         self._provider       = _args['provider'] if 'provider' in _args else None
         # _info['host'] = 'localhost' if 'host' not in _args else _args['host']
@@ -83,10 +84,16 @@ class SQLRW :
         
         self.conn = _handler.connect(**_info)
         self._engine = _args['sqlalchemy']  if 'sqlalchemy' in _args else None
+    def meta(self,**_args):
+        return []
+    def _tablename(self,name) :
+        
+        return self.schema +'.'+name if self.schema not in [None, ''] and '.' not in name else name 
     def has(self,**_args):
         found = False
         try:
-            table = _args['table']
+            
+            table = self._tablename(_args['table'])
             sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
             if self._engine :
                 _conn = self._engine.connect()
@@ -172,20 +179,23 @@ class SQLWriter(SQLRW,Writer):
     def init(self,fields=None):
         if not fields :
             try:                
-                self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
+                table = self._tablename(self.table)
+                self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",table),self.conn).columns.tolist()
             finally:
                 pass
         else:
             self.fields = fields;
 
     def make(self,**_args):
-
+        table = self._tablename(self.table) if 'table' not in _args else self._tablename(_args['table'])
         if 'fields' in _args :
-            fields = _args['fields']            
-            sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+            fields = _args['fields']  
+            # table = self._tablename(self.table)          
+            sql = " ".join(["CREATE TABLE",table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+            print (sql)
         else:
             schema = _args['schema']
-            N = len(schema)
+            
             _map = _args['map'] if 'map' in _args else {}
             sql = [] # ["CREATE TABLE ",_args['table'],"("]
             for _item in schema :
@@ -194,7 +204,8 @@ class SQLWriter(SQLRW,Writer):
                     _type = _map[_type]
                 sql = sql + [" " .join([_item['name'], ' ',_type])]
             sql = ",".join(sql)
-            sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"]
+            # table = self._tablename(_args['table'])
+            sql = ["CREATE TABLE ",table,"( ",sql," )"]
             sql = " ".join(sql)
             # sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"])
         cursor = self.conn.cursor()
@@ -235,8 +246,8 @@ class SQLWriter(SQLRW,Writer):
         #     info = [info]  if type(info) == dict else info.values.tolist()       
         cursor = self.conn.cursor()
         try:
-            
-            _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
+            table = self._tablename(self.table)
+            _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",table) #.replace(":table",self.table).replace(":fields",_fields)
             if self._inspect :
                 for _row in info :
                     fields = list(_row.keys())
@@ -285,16 +296,18 @@ class SQLWriter(SQLRW,Writer):
                     _info = pd.DataFrame(info)
             
                 
-                if self._engine :
-                    # pd.to_sql(_info,self._engine)
-                    _info.to_sql(self.table,self._engine,if_exists='append',index=False)
-                else:
-                    _fields = ",".join(self.fields)
-                    _sql = _sql.replace(":fields",_fields)
-                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                    _sql = _sql.replace(":values",values)
+                # if self._engine :
+                #     # pd.to_sql(_info,self._engine)
+                #     print (_info.columns.tolist())
+                #     rows = _info.to_sql(table,self._engine,if_exists='append',index=False)
+                #     print ([rows])
+                # else:
+                _fields = ",".join(self.fields)
+                _sql = _sql.replace(":fields",_fields)
+                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                _sql = _sql.replace(":values",values)
                     
-                    cursor.executemany(_sql,_info.values.tolist())  
+                cursor.executemany(_sql,_info.values.tolist())  
                 # cursor.commit() 
             
             # self.conn.commit()