Browse Source

bug fix: sqlwriter

Steve Nyemba 3 years ago
parent
commit
10adde7a08
2 changed files with 42 additions and 14 deletions
  1. 38 13
      bin/transport
  2. 4 1
      transport/sql.py

+ 38 - 13
bin/transport

@@ -40,8 +40,7 @@ class Post(Process):
 		self.writer = 	transport.factory.instance(**args['target'])
 		self.rows 	=	 args['rows']
 	def run(self):
-		_info = {"values":self.rows} if 'couch' in self.PROVIDER else self.rows
-		
+		_info = {"values":self.rows} if 'couch' in self.PROVIDER else self.rows		
 		self.writer.write(_info)
 		self.writer.close()
 
@@ -53,6 +52,7 @@ class ETL (Process):
 		self.reader = transport.factory.instance(**_args['source'])
 		self._oargs = _args['target'] #transport.factory.instance(**_args['target'])
 		self.JOB_COUNT =  _args['jobs']
+		self.jobs = []
 		# self.logger = transport.factory.instance(**_args['logger'])
 	def log(self,**_args) :
 		_args['name']  = self.name
@@ -61,7 +61,7 @@ class ETL (Process):
 		idf = self.reader.read()
 		idf = pd.DataFrame(idf)
 		idf.columns = [str(name).replace("b'",'').replace("'","").strip() for name in idf.columns.tolist()]
-		self.log(rows=idf.shape[0],cols=idf.shape[1])
+		self.log(rows=idf.shape[0],cols=idf.shape[1],jobs=self.JOB_COUNT)
 
 		#
 		# writing the data to a designated data source 
@@ -69,27 +69,52 @@ class ETL (Process):
 		try:
 			self.log(module='write',action='partitioning')
 			rows = np.array_split(np.arange(idf.shape[0]),self.JOB_COUNT)
-			jobs = []
+			
 			for i in rows :
+				_id = 'segment #'.join([str(rows.index(i)),self.name])
 				segment = idf.loc[i,:] #.to_dict(orient='records')
-				proc = Post(target = self._oargs,rows = segment)
-				jobs.append(proc)
+				proc = Post(target = self._oargs,rows = segment,name=_id)
+				self.jobs.append(proc)
 				proc.start()
 
-			self.log(module='write',action='working ...')
-			while jobs :
-				jobs = [proc for proc in jobs if proc.is_alive()]
-				time.sleep(2)
-			self.log(module='write',action='completed')
+			self.log(module='write',action='working ...',name=self.name)
+			
 		except Exception as e:
 			print (e)
+		
+	def is_done(self):
+		self.jobs = [proc for proc in self.jobs if proc.is_alive()]
+		return len(self.jobs) == 0
+def apply(_args) :
+	"""
+	This function will apply a set of commands against a data-store. The expected structure is as follows :
+	{"store":...,"apply":[]}	
+	"""
+	handler = transport.factory.instance(**_args['store'])
+	for cmd in _args['apply'] :
+		handler.apply(cmd)
+	handler.close()
 if __name__ == '__main__' :
 	_info = json.loads(open (SYS_ARGS['config']).read())
-	
+	index = int(SYS_ARGS['index']) if 'index' in SYS_ARGS else None
+	procs = []
 	for _config in _info :
 		if 'source' in SYS_ARGS :
 			_config['source'] = {"type":"disk.DiskReader","args":{"path":SYS_ARGS['source'],"delimiter":","}}
 
 		_config['jobs']  = 10 if 'jobs' not in SYS_ARGS else int(SYS_ARGS['jobs'])
 		etl = ETL (**_config)
-		etl.start()
+		etl.start()
+		procs.append(etl)
+		if index and _info.index(_config) == index :
+			break
+	#
+	#
+	N = len(procs)
+	while procs :
+		procs = [thread for thread in procs if not thread.is_done()]
+		if len(procs) < N :
+			print (["Finished ",(N-len(procs)), " remaining ", len(procs)])
+			N = len(procs)
+		time.sleep(1)
+	print ("We're done !!")

+ 4 - 1
transport/sql.py

@@ -58,6 +58,7 @@ class SQLRW :
         # _handler = SQLWriter.DRIVERS[_args['provider']]
         _handler = SQLWriter.REFERENCE[_provider]['handler']
         self._dtype = SQLWriter.REFERENCE[_provider]['dtype'] if 'dtype' not in _args else _args['dtype']
+        self._provider = _provider
         if _handler == nz :
             _info['database'] = _info['dbname']
             _info['securityLevel'] = 0
@@ -199,11 +200,13 @@ class SQLWriter(SQLRW,Writer):
                 # _sql = _sql.replace(":fields",_fields)
                 # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
                 _sql = _sql.replace("(:fields)","")
-                values = ", ".join('?'*len(self.fields))
+                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
                 _sql = _sql.replace(":values",values)
+                print (_sql)
                 
                 # for row in info :
                 #     values = ["'".join(["",value,""]) if not str(value).isnumeric() else value for value in row.values()]
+                
                 cursor.executemany(_sql,info)   
             
             # self.conn.commit()