Bläddra i källkod

bug fix: transport cli and write function for sql

Steve Nyemba 3 år sedan
förälder
incheckning
cfc683c1b3
2 ändrade filer med 42 tillägg och 20 borttagningar
  1. 26 4
      bin/transport
  2. 16 16
      transport/sql.py

+ 26 - 4
bin/transport

@@ -68,11 +68,25 @@ class Post(Process):
 		#
 		# If the table doesn't exists maybe create it ?
 		#
-		self.rows 	=	 args['rows']
+		self.rows 	=	 args['rows'].fillna('')
+		
 		
 	def run(self):
 		_info = {"values":self.rows} if 'couch' in self.PROVIDER else self.rows	
-		
+		ltypes = self.rows.dtypes.values
+		columns = self.rows.dtypes.index.tolist()
+		if not self.writer.has() :
+
+			
+			self.writer.make(fields=columns)
+			# self.log(module='write',action='make-table',input={"name":self.writer.table})
+		for name in columns :
+			if _info[name].dtype in ['int32','int64','int','float','float32','float64'] :
+				value = 0
+			else:
+				value = ''
+			_info[name] = _info[name].fillna(value)
+		print (_info)
 		self.writer.write(_info)
 		self.writer.close()
 
@@ -107,6 +121,8 @@ class ETL (Process):
 		else:
 			idf = self.reader.read() 
 		idf = pd.DataFrame(idf)		
+		# idf = idf.replace({np.nan: None}, inplace = True)
+
 		idf.columns = [str(name).replace("b'",'').replace("'","").strip() for name in idf.columns.tolist()]
 		self.log(rows=idf.shape[0],cols=idf.shape[1],jobs=self.JOB_COUNT)
 
@@ -114,6 +130,8 @@ class ETL (Process):
 		# writing the data to a designated data source 
 		#
 		try:
+			
+			
 			self.log(module='write',action='partitioning')
 			rows = np.array_split(np.arange(idf.shape[0]),self.JOB_COUNT)
 			#
@@ -152,9 +170,13 @@ if __name__ == '__main__' :
 
 		_config['jobs']  = 10 if 'jobs' not in SYS_ARGS else int(SYS_ARGS['jobs'])
 		etl = ETL (**_config)
-		etl.start()
-		procs.append(etl)
+		if not index :
+			
+			etl.start()
+			procs.append(etl)
 		if index and _info.index(_config) == index :
+			procs = [etl]
+			etl.start()
 			break
 	#
 	#

+ 16 - 16
transport/sql.py

@@ -93,7 +93,7 @@ class SQLRW :
         found = False
         try:
             
-            table = self._tablename(_args['table'])
+            table = self._tablename(_args['table'])if 'table' in _args else self._tablename(self.table)
             sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
             if self._engine :
                 _conn = self._engine.connect()
@@ -192,9 +192,9 @@ class SQLWriter(SQLRW,Writer):
             fields = _args['fields']  
             # table = self._tablename(self.table)          
             sql = " ".join(["CREATE TABLE",table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
-            print (sql)
+            
         else:
-            schema = _args['schema']
+            schema = _args['schema'] if 'schema' in _args else ''
             
             _map = _args['map'] if 'map' in _args else {}
             sql = [] # ["CREATE TABLE ",_args['table'],"("]
@@ -214,7 +214,7 @@ class SQLWriter(SQLRW,Writer):
             cursor.execute(sql)
         except Exception as e :
             print (e)
-            print (sql)
+            # print (sql)
             pass
         finally:
             # cursor.close()
@@ -296,18 +296,18 @@ class SQLWriter(SQLRW,Writer):
                     _info = pd.DataFrame(info)
             
                 
-                # if self._engine :
-                #     # pd.to_sql(_info,self._engine)
-                #     print (_info.columns.tolist())
-                #     rows = _info.to_sql(table,self._engine,if_exists='append',index=False)
-                #     print ([rows])
-                # else:
-                _fields = ",".join(self.fields)
-                _sql = _sql.replace(":fields",_fields)
-                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                _sql = _sql.replace(":values",values)
+                if self._engine :
+                    # pd.to_sql(_info,self._engine)
                     
-                cursor.executemany(_sql,_info.values.tolist())  
+                    rows = _info.to_sql(table,self._engine,if_exists='append',index=False)
+                    
+                else:
+                    _fields = ",".join(self.fields)
+                    _sql = _sql.replace(":fields",_fields)
+                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                    _sql = _sql.replace(":values",values)
+                        
+                    cursor.executemany(_sql,_info.values.tolist())  
                 # cursor.commit() 
             
             # self.conn.commit()
@@ -338,7 +338,7 @@ class BigQuery:
         :param table    name of the name WITHOUT including dataset
         :param sql      sql query to be pulled,
         """
-        table = _args['table']
+        table = _args['table'] 
         
         ref     = self.client.dataset(self.dataset).table(table)
         return self.client.get_table(ref).schema