Browse Source

Refactored, including population risk assessment

Steve L. Nyemba -- The Architect, 6 years ago
parent
commit
c3066408c9
2 changed files, with 241 additions and 87 deletions
  1. src/pandas_risk.py (+97, -4)
  2. src/risk.py (+144, -83)

+ 97 - 4
src/pandas_risk.py

@@ -22,16 +22,108 @@
 """
 import pandas as pd
 import numpy as np
-
+import time
 @pd.api.extensions.register_dataframe_accessor("deid")
 class deid :
     """
         This class is a deidentification class that will compute risk (marketer, prosecutor) given a pandas dataframe
     """
     def __init__(self,df):
-        self._df = df
+        self._df = df.fillna(' ')
+    def explore(self,**args):
+        """
+        This function runs experiments by generating random policies (combinations of attributes).
+        It is intended to explore a variety of policies and evaluate their associated risk.
+
+        @param pop|sample   data-frame with the population reference
+        @param id           key field that uniquely identifies a patient/customer ...
+        @param num_runs     number of random policies to generate (default 5)
+        """
+        id = args['id'] if 'id' in args else None
+        pop= args['pop'] if 'pop' in args else None
+        # if 'columns' in args :
+        #     cols = args['columns']
+        #     params = {"sample":args['data'],"cols":cols}
+        #     if pop is not None :
+        #         params['pop'] = pop
+        #     return self.evaluate(**params)
+        # else :
+        #
+        # Policies will be generated with a number of runs
+        #
+        RUNS = args['num_runs'] if 'num_runs' in args else 5
+        
+        sample = args['sample'] if 'sample' in args else pd.DataFrame(self._df)
+        
+        k = sample.columns.size -1 if 'field_count' not in args else int(args['field_count'])
+        columns = list(set(sample.columns.tolist()) - set([id] if id is not None else []))
+        o = pd.DataFrame()
+        # pop = args['pop'] if 'pop' in args else None
+        for i in np.arange(RUNS):
+            n = np.random.randint(2,k)
+            
+            cols = np.random.choice(columns,n,replace=False).tolist()            
+            params = {'sample':sample,'cols':cols}
+            if pop is not None :
+                params['pop'] = pop
+            r = self.evaluate(**params)
+            #
+            # let's put the policy in place
+            p =  pd.DataFrame(1*sample.columns.isin(cols)).T
+            p.columns = sample.columns
+            o = o.append(r.join(p))
+            
+        o.index = np.arange(o.shape[0]).astype(np.int64)
+
+        return o
+    def evaluate(self,**args) :
+        """
+        This function will compute the marketer risk; if a population is provided, the marketer risk is evaluated relative to both the population and the sample.
+        @param sample   data-frame with the data to be processed
+        @param cols     the columns (policy) to be considered
+        @param pop      population dataset
+        @param flag     user-defined flag (not used in the computation)
+        """
+        x_i  = args['sample'] if 'sample' in args else pd.DataFrame(self._df)
+        cols = args['cols'] if 'cols' in args else x_i.columns.tolist()
+        flag = args['flag'] if 'flag' in args else 'UNFLAGGED'
+        x_i_values = x_i.groupby(cols,as_index=False).size().values
+        SAMPLE_GROUP_COUNT = x_i_values.size
+        SAMPLE_FIELD_COUNT = len(cols)
+        SAMPLE_POPULATION  = x_i_values.sum()
+        
+        SAMPLE_MARKETER    = SAMPLE_GROUP_COUNT / np.float64(SAMPLE_POPULATION)
+        SAMPLE_PROSECUTOR  = 1/ np.min(x_i_values).astype(np.float64)
+        if 'pop' in args :
+            Yi = args['pop']            
+            y_i= pd.DataFrame({"group_size":Yi.groupby(cols,as_index=False).size()}).reset_index()
+            # y_i['group'] = pd.DataFrame({"group_size":args['pop'].groupby(cols,as_index=False).size().values}).reset_index()
+            # x_i = pd.DataFrame({"group_size":x_i.groupby(cols,as_index=False).size().values}).reset_index()
+            x_i = pd.DataFrame({"group_size":x_i.groupby(cols,as_index=False).size()}).reset_index()
+            SAMPLE_RATIO = int(100 * SAMPLE_POPULATION / args['pop'].shape[0]) # sample size as a percentage of the population size
+            r = pd.merge(x_i,y_i,on=cols,how='inner')
+            r['marketer'] = r.apply(lambda row: (row.group_size_x / np.float64(row.group_size_y)) /np.sum(x_i.group_size) ,axis=1)
+            r['sample %'] = np.repeat(SAMPLE_RATIO,r.shape[0])
+            r['tier'] = np.repeat(flag,r.shape[0])
+            r['sample marketer'] =  np.repeat(SAMPLE_MARKETER,r.shape[0])
+            r = r.groupby(['sample %','tier','sample marketer'],as_index=False).sum()[['sample %','marketer','sample marketer','tier']]
+        else:
+            r = pd.DataFrame({"marketer":[SAMPLE_MARKETER],"prosecutor":[SAMPLE_PROSECUTOR],"field_count":[SAMPLE_FIELD_COUNT],"group_count":[SAMPLE_GROUP_COUNT]})
+        return r
     
-    def risk(self,**args):
+    def _risk(self,**args):
         """
             @param  id          name of patient field            
             @params num_runs    number of runs (default will be 100)
@@ -50,7 +142,7 @@ class deid :
         k = len(columns)
         N = self._df.shape[0]
         tmp = self._df.fillna(' ')
-        np.random.seed(1)
+        np.random.seed(int(time.time()))
         for i in range(0,num_runs) :
             
             #
@@ -85,6 +177,7 @@ class deid :
                     [
                         {
                             "group_count":x_.size,
+                            
                             "patient_count":N,
                             "field_count":n,
                             "marketer": x_.size / np.float64(np.sum(x_)),

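For reference, a minimal usage sketch of the new accessor (the toy data-frame, its column names, and the import are placeholders for illustration, assuming src/ is on the Python path; they are not part of the commit):

    import pandas as pd
    import pandas_risk  # importing the module registers the .deid accessor

    df = pd.DataFrame({
        "gender": ["M", "F", "F", "M", "F"],
        "age":    [30, 30, 40, 30, 40],
        "race":   ["B", "W", "W", "B", "W"],
        "zip":    ["37203", "37203", "37212", "37203", "37212"]
    })

    # marketer/prosecutor risk for one explicit policy (a set of quasi-identifiers);
    # marketer = group_count / N, prosecutor = 1 / smallest group size
    r = df.deid.evaluate(cols=["gender", "age"])
    print r[["marketer", "prosecutor", "field_count", "group_count"]]

    # draw 5 random policies; each output row carries the risk measures plus
    # 0/1 indicators marking which columns the policy used
    o = df.deid.explore(num_runs=5)
    print o
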
+ 144 - 83
src/risk.py

@@ -146,7 +146,7 @@ class utils :
         
         return " ".join(SQL).replace(":fields"," , ".join(fields))
 
-class risk :
+class SQLRisk :
     """
         This class will handle the creation of an SQL query that computes marketer and prosecutor risk (for now)
     """
@@ -186,102 +186,163 @@ class risk :
     
 
         
+class UtilHandler :
+    def __init__(self,**args) :
+        """
+            @param path         path to the service-account file
+            @param dataset      input dataset name
+            @param key_field    key field (e.g. person_id)
+            @param key_table    name of the table that holds the key field; it is moved to the head of the table list
+            @param filter       optional list of table names to restrict processing to
 
+        """
+        self.path        = args['path']
+        self.client      = bq.Client.from_service_account_json(self.path)
+        dataset   = args['dataset']
+        self.key         = args['key_field'] 
 
-if 'action' in SYS_ARGS and  SYS_ARGS['action'] in ['create','compute','migrate'] :
+        self.mytools = utils(client = self.client)
+        self.tables  = self.mytools.get_tables(dataset=dataset,client=self.client,key=self.key)
+        index = [self.tables.index(item) for item in self.tables if item['name'] == args['key_table']][0]
+        if index != 0 :
+            first = self.tables[0]
+            aux = self.tables[index]
+            self.tables[0] = aux
+            self.tables[index] = first
+        if 'filter' in args :
+            self.tables = [item for item in self.tables if item['name'] in args['filter']]
 
-    path = SYS_ARGS['path']
-    client = bq.Client.from_service_account_json(path)
-    i_dataset = SYS_ARGS['i_dataset']
-    key = SYS_ARGS['key'] 
 
-    mytools = utils(client = client)
-    tables = mytools.get_tables(dataset=i_dataset,client=client,key=key)
-    # print len(tables)
-    # tables = tables[:6]
+    def create_table(self,**args):
+        """
+            @param path absolute filename to save the create statement
 
-    if SYS_ARGS['action'] == 'create' :
-        #usage:
-        #   create --i_dataset <in dataset> --key <patient id> --o_dataset <out dataset> --table <table|file> [--file] --path <bq JSON account file>
-        #
-        create_sql = mytools.get_sql(tables=tables,key=key) #-- The create statement
-        o_dataset = SYS_ARGS['o_dataset']
-        table = SYS_ARGS['table']
-        if 'file' in SYS_ARGS :
-            f = open(table+'.sql','w')
+        """
+        create_sql = self.mytools.get_sql(tables=self.tables,key=self.key) #-- The create statement
+        # o_dataset = SYS_ARGS['o_dataset']
+        # table = SYS_ARGS['table']
+        if 'path' in args:
+            f = open(args['path'],'w')
             f.write(create_sql)
             f.close()
-        else:
-            job = bq.QueryJobConfig()
-            job.destination = client.dataset(o_dataset).table(table)
-            job.use_query_cache = True
-            job.allow_large_results = True 
-            job.priority = 'BATCH'
-            job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY)
+        return create_sql
+    def migrate_tables(self,**args):
+        """
+            This function will migrate tables from one dataset to another.
+            Migration reduces each candidate table so that it represents a patient only by her quasi-identifiers.
+            @param dataset  target dataset
+        """
+        o_dataset = args['dataset'] if 'dataset' in args else None
+        p = []
+        for table in self.tables:
+            sql = " ".join(["SELECT ",",".join(table['fields']) ," FROM (",self.mytools.get_filtered_table(table,self.key),") as ",table['name']])        
+            p.append(sql)
+            if o_dataset :
+                job = bq.QueryJobConfig()
+                job.destination = self.client.dataset(o_dataset).table(table['name'])
+                job.use_query_cache = True
+                job.allow_large_results = True 
+                job.priority = 'INTERACTIVE'
+                job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY)
 
-            r = client.query(create_sql,location='US',job_config=job) 
+                r = self.client.query(sql,location='US',job_config=job) 
+
+                print [table['full_name'],' ** ',r.job_id,' ** ',r.state]
+        return p
+
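
For reference, a minimal usage sketch of the UtilHandler introduced above (the service-account path and the dataset/table names are placeholders for illustration, assuming src/ is on the Python path; they are not part of the commit):

    from risk import UtilHandler

    handler = UtilHandler(
        path="/path/to/service-account.json",  # BigQuery service-account file
        dataset="raw",                         # input dataset
        key_field="person_id",                 # field that uniquely identifies a patient
        key_table="person"                     # this table is moved to the head of the list
    )

    # write the generated CREATE statement to a file and also get it back as a string
    create_sql = handler.create_table(path="create.sql")

    # reduce each table to the patient's quasi-identifiers and, because a target
    # dataset is given, submit one BigQuery copy job per table
    queries = handler.migrate_tables(dataset="risk_o")
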
+# if 'action' in SYS_ARGS and  SYS_ARGS['action'] in ['create','compute','migrate'] :
+
+#     path = SYS_ARGS['path']
+#     client = bq.Client.from_service_account_json(path)
+#     i_dataset = SYS_ARGS['i_dataset']
+#     key = SYS_ARGS['key'] 
+
+#     mytools = utils(client = client)
+#     tables = mytools.get_tables(dataset=i_dataset,client=client,key=key)
+#     # print len(tables)
+#     # tables = tables[:6]
+
+#     if SYS_ARGS['action'] == 'create' :
+#         #usage:
+#         #   create --i_dataset <in dataset> --key <patient id> --o_dataset <out dataset> --table <table|file> [--file] --path <bq JSON account file>
+#         #
+#         create_sql = mytools.get_sql(tables=tables,key=key) #-- The create statement
+#         o_dataset = SYS_ARGS['o_dataset']
+#         table = SYS_ARGS['table']
+#         if 'file' in SYS_ARGS :
+#             f = open(table+'.sql','w')
+#             f.write(create_sql)
+#             f.close()
+#         else:
+#             job = bq.QueryJobConfig()
+#             job.destination = client.dataset(o_dataset).table(table)
+#             job.use_query_cache = True
+#             job.allow_large_results = True 
+#             job.priority = 'BATCH'
+#             job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY)
+
+#             r = client.query(create_sql,location='US',job_config=job) 
             
-            print [r.job_id,' ** ',r.state]
-    elif SYS_ARGS['action'] == 'migrate' :
-        #
-        #
+#             print [r.job_id,' ** ',r.state]
+#     elif SYS_ARGS['action'] == 'migrate' :
+#         #
+#         #
 
-        o_dataset = SYS_ARGS['o_dataset']
-        for table in tables:
-            sql = " ".join(["SELECT ",",".join(table['fields']) ," FROM (",mytools.get_filtered_table(table,key),") as ",table['name']])
-            print ""
-            print sql
-            print ""
-            # job = bq.QueryJobConfig()
-            # job.destination = client.dataset(o_dataset).table(table['name'])
-            # job.use_query_cache = True
-            # job.allow_large_results = True 
-            # job.priority = 'INTERACTIVE'
-            # job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY)
+#         o_dataset = SYS_ARGS['o_dataset']
+#         for table in tables:
+#             sql = " ".join(["SELECT ",",".join(table['fields']) ," FROM (",mytools.get_filtered_table(table,key),") as ",table['name']])
+#             print ""
+#             print sql
+#             print ""
+#             # job = bq.QueryJobConfig()
+#             # job.destination = client.dataset(o_dataset).table(table['name'])
+#             # job.use_query_cache = True
+#             # job.allow_large_results = True 
+#             # job.priority = 'INTERACTIVE'
+#             # job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY)
 
-            # r = client.query(sql,location='US',job_config=job) 
+#             # r = client.query(sql,location='US',job_config=job) 
             
-            # print [table['full_name'],' ** ',r.job_id,' ** ',r.state]
+#             # print [table['full_name'],' ** ',r.job_id,' ** ',r.state]
 
 
-        pass
-    else:
-        #
-        #
-        tables  = [tab for tab in tables if tab['name'] == SYS_ARGS['table'] ]  
-        limit   = int(SYS_ARGS['limit']) if 'limit' in SYS_ARGS else 1
-        if tables :            
-            risk= risk()
-            df  = pd.DataFrame()
-            dfs = pd.DataFrame()
-            np.random.seed(1)
-            for i in range(0,limit) :
-                r = risk.get_sql(key=SYS_ARGS['key'],table=tables[0])
-                sql = r['sql']
-                dfs = dfs.append(r['stream'],sort=True)
-                df = df.append(pd.read_gbq(query=sql,private_key=path,dialect='standard').join(dfs))
-                # df = df.join(dfs,sort=True)
-                df.to_csv(SYS_ARGS['table']+'.csv')
-                # dfs.to_csv(SYS_ARGS['table']+'_stream.csv') 
-                print [i,' ** ',df.shape[0],pd.DataFrame(r['stream']).shape]
-                time.sleep(2)
+#         pass
+#     else:
+#         #
+#         #
+#         tables  = [tab for tab in tables if tab['name'] == SYS_ARGS['table'] ]  
+#         limit   = int(SYS_ARGS['limit']) if 'limit' in SYS_ARGS else 1
+#         if tables :            
+#             risk= risk()
+#             df  = pd.DataFrame()
+#             dfs = pd.DataFrame()
+#             np.random.seed(1)
+#             for i in range(0,limit) :
+#                 r = risk.get_sql(key=SYS_ARGS['key'],table=tables[0])
+#                 sql = r['sql']
+#                 dfs = dfs.append(r['stream'],sort=True)
+#                 df = df.append(pd.read_gbq(query=sql,private_key=path,dialect='standard').join(dfs))
+#                 # df = df.join(dfs,sort=True)
+#                 df.to_csv(SYS_ARGS['table']+'.csv')
+#                 # dfs.to_csv(SYS_ARGS['table']+'_stream.csv') 
+#                 print [i,' ** ',df.shape[0],pd.DataFrame(r['stream']).shape]
+#                 time.sleep(2)
                 
     
-else:
-    print 'ERROR'
-    pass
+# else:
+#     print 'ERROR'
+#     pass
 
-# r = risk(path='/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json', i_dataset='raw',o_dataset='risk_o',o_table='mo')
-# tables = r.get_tables('raw','person_id')
-# sql = r.get_sql(tables=tables[:3],key='person_id')
-# #
-# # let's post this to a designated location
-# #
-# f = open('foo.sql','w')
-# f.write(sql)
-# f.close()
-# r.get_sql(tables=tables,key='person_id')
-# p = r.compute()
-# print p
-# p.to_csv("risk.csv")
-# r.write('foo.sql')
+# # r = risk(path='/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json', i_dataset='raw',o_dataset='risk_o',o_table='mo')
+# # tables = r.get_tables('raw','person_id')
+# # sql = r.get_sql(tables=tables[:3],key='person_id')
+# # #
+# # # let's post this to a designated location
+# # #
+# # f = open('foo.sql','w')
+# # f.write(sql)
+# # f.close()
+# # r.get_sql(tables=tables,key='person_id')
+# # p = r.compute()
+# # print p
+# # p.to_csv("risk.csv")
+# # r.write('foo.sql')