瀏覽代碼

bugfix: ETL multiprocessing

Steve Nyemba 3 年之前
父節點
當前提交
105ff00224
共有 3 個文件被更改,包括 57 次插入76 次删除
  1. 22 12
      bin/transport
  2. 1 1
      transport/__init__.py
  3. 34 63
      transport/sql.py

+ 22 - 12
bin/transport

@@ -75,10 +75,10 @@ class Post(Process):
 		_info = {"values":self.rows} if 'couch' in self.PROVIDER else self.rows	
 		ltypes = self.rows.dtypes.values
 		columns = self.rows.dtypes.index.tolist()
-		if not self.writer.has() :
+		# if not self.writer.has() :
 
 			
-			self.writer.make(fields=columns)
+			# self.writer.make(fields=columns)
 			# self.log(module='write',action='make-table',input={"name":self.writer.table})
 		for name in columns :
 			if _info[name].dtype in ['int32','int64','int','float','float32','float64'] :
@@ -86,7 +86,7 @@ class Post(Process):
 			else:
 				value = ''
 			_info[name] = _info[name].fillna(value)
-		print (_info)
+		
 		self.writer.write(_info)
 		self.writer.close()
 
@@ -94,6 +94,7 @@ class Post(Process):
 class ETL (Process):
 	def __init__(self,**_args):
 		super().__init__()
+		
 		self.name 	= _args['id']
 		if 'provider' not in _args['source'] :
 			#@deprecate
@@ -133,18 +134,24 @@ class ETL (Process):
 			
 			
 			self.log(module='write',action='partitioning')
-			rows = np.array_split(np.arange(idf.shape[0]),self.JOB_COUNT)
+			rows = np.array_split(np.arange(0,idf.shape[0]),self.JOB_COUNT)
+			
 			#
 			# @TODO: locks
-			for i in rows :
-				_id = 'segment #'.join([str(rows.index(i)),self.name])
-				segment = idf.loc[i,:] #.to_dict(orient='records')
+			for i in np.arange(self.JOB_COUNT) :
+				print ()
+				print (i)
+				_id = 'segment # '.join([str(i),' ',self.name])
+				indexes = rows[i]
+				segment = idf.loc[indexes,:].copy() #.to_dict(orient='records')
 				proc = Post(target = self._oargs,rows = segment,name=_id)
 				self.jobs.append(proc)
 				proc.start()
 
-			self.log(module='write',action='working ...',name=self.name)
-			
+				self.log(module='write',action='working',segment=_id)
+			# while poc :
+			# 	proc = [job for job in proc if job.is_alive()]
+			# 	time.sleep(1)
 		except Exception as e:
 			print (e)
 		
@@ -168,13 +175,16 @@ if __name__ == '__main__' :
 		if 'source' in SYS_ARGS :
 			_config['source'] = {"type":"disk.DiskReader","args":{"path":SYS_ARGS['source'],"delimiter":","}}
 
-		_config['jobs']  = 10 if 'jobs' not in SYS_ARGS else int(SYS_ARGS['jobs'])
+		_config['jobs']  = 3 if 'jobs' not in SYS_ARGS else int(SYS_ARGS['jobs'])
 		etl = ETL (**_config)
-		if not index :
+		if index is None:	
 			
 			etl.start()
 			procs.append(etl)
-		if index and _info.index(_config) == index :
+		
+		elif _info.index(_config) == index :
+			
+			# print (_config)
 			procs = [etl]
 			etl.start()
 			break

+ 1 - 1
transport/__init__.py

@@ -162,7 +162,7 @@ def instance(**_args):
 			if provider not in ['mongodb','couchdb','bigquery'] :
 				uri = ''.join([provider,"://",account,host,'/',database])
 				
-				e = sqlalchemy.create_engine (uri)
+				e = sqlalchemy.create_engine (uri,future=True)
 				args['sqlalchemy'] = e 
 			#
 			# @TODO: Include handling of bigquery with SQLAlchemy

+ 34 - 63
transport/sql.py

@@ -21,7 +21,7 @@ else:
 import json
 from google.oauth2 import service_account
 from google.cloud import bigquery as bq
-from multiprocessing import Lock
+from multiprocessing import Lock, RLock
 import pandas as pd
 import numpy as np
 import nzpy as nz   #--- netezza drivers
@@ -30,7 +30,7 @@ import os
 
 
 class SQLRW :
-   
+    lock = RLock()
     DRIVERS  = {"postgresql":pg,"redshift":pg,"mysql":my,"mariadb":my,"netezza":nz}
     REFERENCE = {
         "netezza":{"port":5480,"handler":nz,"dtype":"VARCHAR(512)"},
@@ -71,7 +71,7 @@ class SQLRW :
         # _handler = SQLWriter.REFERENCE[_provider]['handler']
         _handler        = _args['driver']  #-- handler to the driver
         self._dtype     = _args['default']['type'] if 'default' in _args and 'type' in _args['default'] else 'VARCHAR(256)'
-        self._provider  = _args['provider']
+        # self._provider  = _args['provider']
         # self._dtype = SQLWriter.REFERENCE[_provider]['dtype'] if 'dtype' not in _args else _args['dtype']
         # self._provider = _provider
         if _handler == nz :
@@ -173,7 +173,7 @@ class SQLWriter(SQLRW,Writer):
         # In the advent that data typing is difficult to determine we can inspect and perform a default case
         # This slows down the process but improves reliability of the data
         # NOTE: Proper data type should be set on the target system if their source is unclear.
-        self._inspect = False if 'inspect' not in _args else _args['inspect']
+        
         self._cast = False if 'cast' not in _args else _args['cast']
         
     def init(self,fields=None):
@@ -244,78 +244,49 @@ class SQLWriter(SQLRW,Writer):
         #     #
         #     # We are assuming 2 cases i.e dict or pd.DataFrame
         #     info = [info]  if type(info) == dict else info.values.tolist()       
-        cursor = self.conn.cursor()
+        
         try:
             table = self._tablename(self.table)
             _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",table) #.replace(":table",self.table).replace(":fields",_fields)
-            if self._inspect :
-                for _row in info :
-                    fields = list(_row.keys())
-                    if self._cast == False :
-                        values = ",".join(_row.values())
-                    else:
-                        # values = "'"+"','".join([str(value) for value in _row.values()])+"'"
-                        values = [",".join(["%(",name,")s"]) for name in _row.keys()]
-                    
-                    # values = [ "".join(["'",str(_row[key]),"'"]) if np.nan(_row[key]).isnumeric() else str(_row[key]) for key in _row]
-                    # print (values)
-                    query = _sql.replace(":fields",",".join(fields)).replace(":values",values)
-                    if type(info) == pd.DataFrame :
-                        _values = info.values.tolist()
-                    elif type(info) == list and type(info[0]) == dict:
-                        print ('........')
-                        _values = [tuple(item.values()) for item in info]
-                    else:
-                        _values = info;
-                    cursor.execute(query,_values)
-                
-
-                pass
+           
+            if type(info) == list :
+                _info = pd.DataFrame(info)
+            elif type(info) == dict :
+                _info = pd.DataFrame([info])
             else:
-                
-                # _sql = _sql.replace(":fields",_fields)
-                # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
-                # _sql = _sql.replace("(:fields)","")
-                
-                # _sql = _sql.replace(":values",values)
-                # if type(info) == pd.DataFrame :
-                #     _info = info[self.fields].values.tolist()
-                    
-                # elif  type(info) == dict :
-                #     _info = info.values()
-                # else:
-                #     # _info = []
-
-                #     _info = pd.DataFrame(info)[self.fields].values.tolist()
-                    # _info = pd.DataFrame(info).to_dict(orient='records')
-                if type(info) == list :
-                    _info = pd.DataFrame(info)
-                elif type(info) == dict :
-                    _info = pd.DataFrame([info])
-                else:
-                    _info = pd.DataFrame(info)
+                _info = pd.DataFrame(info)
+        
             
+            if _info.shape[0] == 0 :
                 
-                if self._engine :
-                    # pd.to_sql(_info,self._engine)
-                    
-                    rows = _info.to_sql(table,self._engine,schema=self.schema,if_exists='append',index=False)
-                    
+                return
+            SQLRW.lock.acquire()
+            if self._engine is not None:
+                # pd.to_sql(_info,self._engine)
+                if self.schema in ['',None] :
+                    rows = _info.to_sql(table,self._engine,if_exists='append',index=False)
                 else:
-                    _fields = ",".join(self.fields)
-                    _sql = _sql.replace(":fields",_fields)
-                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                    _sql = _sql.replace(":values",values)
-                        
-                    cursor.executemany(_sql,_info.values.tolist())  
-                # cursor.commit() 
+                    rows = _info.to_sql(self.table,self._engine,schema=self.schema,if_exists='append',index=False)
+                
+            else:
+                _fields = ",".join(self.fields)
+                _sql = _sql.replace(":fields",_fields)
+                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                _sql = _sql.replace(":values",values)
+                cursor = self.conn.cursor()
+                cursor.executemany(_sql,_info.values.tolist())  
+                cursor.close()
+            # cursor.commit() 
             
             # self.conn.commit()
         except Exception as e:
             print(e)
             pass
         finally:
-            self.conn.commit()            
+            
+            if self._engine is None :
+                self.conn.commit()   
+            SQLRW.lock.release()         
             # cursor.close()
             pass
     def close(self):