Просмотр исходного кода

bug fix: prosecutor risk, marketer risk

Steve L. Nyemba -- The Architect 6 лет назад
Родитель
Сommit
140a4c4573
3 измененных файлов с 374 добавлено и 74 удалено
  1. 131 74
      notebooks/risk.ipynb
  2. 17 0
      src/params.py
  3. 226 0
      src/risk.py

+ 131 - 74
notebooks/risk.ipynb

@@ -2,15 +2,29 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 66,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "dev-deid-600@aou-res-deid-vumc-test.iam.gserviceaccount.com df0ac049-d5b6-416f-ab3c-6321eda919d6 2018-09-25 08:18:34.829000+00:00 DONE\n"
+     ]
+    }
+   ],
    "source": [
     "import pandas as pd\n",
     "import numpy as np\n",
     "from google.cloud import bigquery as bq\n",
     "\n",
-    "client = bq.Client.from_service_account_json('/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json')"
+    "client = bq.Client.from_service_account_json('/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json')\n",
+    "# pd.read_gbq(query=\"select * from raw.observation limit 10\",private_key='/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json')\n",
+    "jobs = client.list_jobs()\n",
+    "for job in jobs :\n",
+    "#     print dir(job)\n",
+    "    print job.user_email,job.job_id,job.started, job.state\n",
+    "    break"
    ]
   },
   {
@@ -25,7 +39,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 181,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -68,7 +82,7 @@
     "    else:\n",
     "        x_ = args['xi']\n",
     "    for xi in x_ :\n",
-    "        fields += (['.'.join([xi['name'],name]) for name in xi['fields'] if name != args['join']])\n",
+    "        fields += (['.'.join([xi['name'], name]) for name in xi['fields'] if name != args['join']])\n",
     "    return fields\n",
     "def generate_sql(**args):\n",
     "    \"\"\"\n",
@@ -97,7 +111,27 @@
     "            tmp.append(ON_SQL)\n",
     "        INNER_JOINS += [JOIN_SQL + \" AND \".join(tmp)]\n",
     "    return SQL + \" \".join(INNER_JOINS)\n",
-    "                \n",
+    "def get_final_sql(**args):\n",
+    "    xo = args['xo']\n",
+    "    xi = args['xi']\n",
+    "    join=args['join']\n",
+    "    prefix = args['prefix'] if 'prefix' in args else ''\n",
+    "    fields = get_fields (xo=xo,xi=xi,join=join)\n",
+    "    k = len(fields)\n",
+    "    n = np.random.randint(2,k) #-- number of fields to select\n",
+    "    i = np.random.randint(0,k,size=n)\n",
+    "    fields = [name for name in fields if fields.index(name) in i]\n",
+    "    base_sql = generate_sql(xo=xo,xi=xi,prefix)\n",
+    "    SQL = \"\"\"\n",
+    "        SELECT AVERAGE(count),size,n as selected_features,k as total_features\n",
+    "        FROM(\n",
+    "            SELECT COUNT(*) as count,count(:join) as pop,sum(:n) as N,sum(:k) as k,:fields\n",
+    "            FROM (:sql)\n",
+    "        GROUP BY :fields\n",
+    "        ) \n",
+    "        order by 1\n",
+    "        \n",
+    "    \"\"\".replace(\":sql\",base_sql)\n",
     "#     sql = \"SELECT :fields FROM :xo.name INNER JOIN :xi.name ON :xi.name.:xi.y = :xo.y \"\n",
     "#     fields = \",\".join(get_fields(xo=xi,xi=xi,join=xi['y']))\n",
     "    \n",
@@ -111,24 +145,39 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 183,
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "xo = {\"name\":\"person\",\"fields\":['person_id','date_of_birth','race','value_as_number']}\n",
+    "xi = [{\"name\":\"measurement\",\"fields\":['person_id','value_as_number','value_source_value']}] #,{\"name\":\"observation\",\"fields\":[\"person_id\",\"value_as_string\",\"observation_source_value\"]}]\n",
+    "# generate_sql(xo=xo,xi=xi,join=\"person_id\",prefix='raw')\n",
+    "fields = get_fields(xo=xo,xi=xi,join='person_id')\n",
+    "ofields = list(fields)\n",
+    "k = len(fields)\n",
+    "n = np.random.randint(2,k) #-- number of fields to select\n",
+    "i = np.random.randint(0,k,size=n)\n",
+    "fields = [name for name in fields if fields.index(name) in i]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "'SELECT :fields FROM raw.person INNER JOIN raw.measurement ON measurement.person_id = person.person_id'"
+       "['person.race', 'person.value_as_number', 'measurement.value_source_value']"
       ]
      },
-     "execution_count": 183,
+     "execution_count": 34,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "xo = {\"name\":\"person\",\"fields\":['person_id','date_of_birth','race']}\n",
-    "xi = [{\"name\":\"measurement\",\"fields\":['person_id','value_as_number','value_source_value']}] #,{\"name\":\"observation\",\"fields\":[\"person_id\",\"value_as_string\",\"observation_source_value\"]}]\n",
-    "generate_sql(xo=xo,xi=xi,join=\"person_id\",prefix='raw')"
+    "fields\n"
    ]
   },
   {
@@ -179,69 +228,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 111,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[u'condition_occurrence.condition_occurrence_id',\n",
-       " u'condition_occurrence.person_id',\n",
-       " u'condition_occurrence.condition_concept_id',\n",
-       " u'condition_occurrence.condition_start_date',\n",
-       " u'condition_occurrence.condition_start_datetime',\n",
-       " u'condition_occurrence.condition_end_date',\n",
-       " u'condition_occurrence.condition_end_datetime',\n",
-       " u'condition_occurrence.condition_type_concept_id',\n",
-       " u'condition_occurrence.stop_reason',\n",
-       " u'condition_occurrence.provider_id',\n",
-       " u'condition_occurrence.visit_occurrence_id',\n",
-       " u'condition_occurrence.condition_source_value',\n",
-       " u'condition_occurrence.condition_source_concept_id',\n",
-       " u'death.death_date',\n",
-       " u'death.death_datetime',\n",
-       " u'death.death_type_concept_id',\n",
-       " u'death.cause_concept_id',\n",
-       " u'death.cause_source_value',\n",
-       " u'death.cause_source_concept_id',\n",
-       " u'device_exposure.device_exposure_id',\n",
-       " u'device_exposure.device_concept_id',\n",
-       " u'device_exposure.device_exposure_start_date',\n",
-       " u'device_exposure.device_exposure_start_datetime',\n",
-       " u'device_exposure.device_exposure_end_date',\n",
-       " u'device_exposure.device_exposure_end_datetime',\n",
-       " u'device_exposure.device_type_concept_id',\n",
-       " u'device_exposure.unique_device_id',\n",
-       " u'device_exposure.quantity',\n",
-       " u'device_exposure.provider_id',\n",
-       " u'device_exposure.visit_occurrence_id',\n",
-       " u'device_exposure.device_source_value',\n",
-       " u'device_exposure.device_source_concept_id',\n",
-       " u'drug_exposure.drug_exposure_id',\n",
-       " u'drug_exposure.drug_concept_id',\n",
-       " u'drug_exposure.drug_exposure_start_date',\n",
-       " u'drug_exposure.drug_exposure_start_datetime',\n",
-       " u'drug_exposure.drug_exposure_end_date',\n",
-       " u'drug_exposure.drug_exposure_end_datetime',\n",
-       " u'drug_exposure.drug_type_concept_id',\n",
-       " u'drug_exposure.stop_reason',\n",
-       " u'drug_exposure.refills',\n",
-       " u'drug_exposure.quantity',\n",
-       " u'drug_exposure.days_supply',\n",
-       " u'drug_exposure.sig',\n",
-       " u'drug_exposure.route_concept_id',\n",
-       " u'drug_exposure.effective_drug_dose',\n",
-       " u'drug_exposure.dose_unit_concept_id',\n",
-       " u'drug_exposure.lot_number',\n",
-       " u'drug_exposure.provider_id',\n",
-       " u'drug_exposure.visit_occurrence_id',\n",
-       " u'drug_exposure.drug_source_value',\n",
-       " u'drug_exposure.drug_source_concept_id',\n",
-       " u'drug_exposure.route_source_value',\n",
-       " u'drug_exposure.dose_unit_source_value']"
+       "array([1, 3, 0, 0])"
       ]
      },
-     "execution_count": 111,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -250,12 +246,7 @@
     "#\n",
     "# find every table with person id at the very least or a subset of fields\n",
     "#\n",
-    "info = get_tables(client,'raw',['person_id'])\n",
-    "# get_fields(xo=names[0],xi=names[1:4],join='person_id')\n",
-    "\n",
-    "# q = ['person_id']\n",
-    "# pairs = list(itertools.combinations(names,len(names)))\n",
-    "# pairs[0]"
+    "np.random.randint(0,4,size=4)"
    ]
   },
   {
@@ -287,6 +278,72 @@
     "x_ = 1"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x_ = pd.DataFrame({\"group\":[1,1,1,1,1], \"size\":[2,1,1,1,1]})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>size</th>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>group</th>\n",
+       "      <th></th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>1.2</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "       size\n",
+       "group      \n",
+       "1       1.2"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x_.groupby(['group']).mean()\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

+ 17 - 0
src/params.py

@@ -0,0 +1,17 @@
+import sys
+SYS_ARGS={}
+if len(sys.argv) > 1 :
+    N = len(sys.argv)
+    for i in range(1,N) :
+        value = 1
+        
+        if sys.argv[i].startswith('--') :
+            key = sys.argv[i].replace('-','')
+            
+            if i + 1 < N and not sys.argv[i+1].startswith('--') :
+                value = sys.argv[i + 1].strip()
+            SYS_ARGS[key] = value
+            i += 2
+        elif 'action' not in SYS_ARGS:
+            SYS_ARGS['action'] = sys.argv[i].strip()
+        

+ 226 - 0
src/risk.py

@@ -0,0 +1,226 @@
+"""
+    Steve L. Nyemba & Brad Malin
+    Health Information Privacy Lab.
+
+    This code is proof of concept as to how risk is computed against a database (at least a schema).
+    The engine will read tables that have a given criteria (patient id) and generate a dataset by performing joins.
+    Because joins are process intensive we decided to add a limit to the records pulled.
+    
+    TL;DR:
+        This engine generates a dataset and computes risk (marketer and prosecutor)    
+    Assumptions:
+        - We assume tables that reference patients will name the keys identically (best practice). This allows us to be able to leverage data store's that don't support referential integrity
+        
+    Usage :
+        
+    Limitations
+        - It works against bigquery for now
+        @TODO:    
+            - Need to write a transport layer (database interface)
+            - Support for referential integrity, so one table can be selected and a dataset derived given referential integrity
+            - Add support for journalist risk
+"""
+import pandas as pd
+import numpy as np
+from google.cloud import bigquery as bq
+import time
+from params import SYS_ARGS
+class utils :
+    """
+        This class is a utility class that will generate SQL-11 compatible code in order to run the risk assessment
+        
+        @TODO: plugins for other data-stores
+    """
+    def __init__(self,**args):
+        # self.path = args['path']
+        self.client = args['client']
+    
+    def get_tables(self,**args): #id,key='person_id'):
+        """
+            This function returns a list of tables given a key. The key is the name of the field that uniquely designates a patient/person
+            in the database. The list of tables are tables that can be joined given the provided field.
+
+            @param key  name of the patient field
+            @param dataset   dataset name
+            @param client   initialized bigquery client ()
+            @return [{name,fields:[],row_count}]
+        """
+        dataset = args['dataset']
+        client  = args['client']
+        key     = args['key']
+        r = []
+        ref = client.dataset(dataset)
+        tables = list(client.list_tables(ref))
+        for table in tables :
+            
+            if table.table_id.strip() in ['people_seed']:
+                print ' skiping ...'
+                continue
+            ref = table.reference
+            table = client.get_table(ref)
+            schema = table.schema
+            rows = table.num_rows
+            if rows == 0 :
+                continue
+            names = [f.name for f in schema]
+            x = list(set(names) & set([key]))
+            if x  :
+                full_name = ".".join([dataset,table.table_id])
+                r.append({"name":table.table_id,"fields":names,"row_count":rows,"full_name":full_name})
+        return r
+    def get_field_name(self,alias,field_name,index):
+        """
+            This function will format the a field name given an index (the number of times it has occurred in projection)
+            The index is intended to avoid a "duplicate field" error (bigquery issue)
+
+            @param alias    alias of the table
+            @param field_name   name of the field to be formatted
+            @param index    the number of times the field appears in the projection
+        """
+        name = [alias,field_name]
+        if index > 0 :
+            return ".".join(name)+" AS :field_name:index".replace(":field_name",field_name).replace(":index",str(index))
+        else:
+            return ".".join(name)
+    def get_sql(self,**args):
+        """
+            This function will generate that will join a list of tables given a key and a limit of records
+            @param tables   list of tables
+            @param  key     key field to be used in the join. The assumption is that the field name is identical across tables (best practice!)
+            @param limit    a limit imposed, in case of ristrictions considering joins are resource intensive
+        """
+        tables  = args['tables'] 
+        key     = args['key']
+        limit   = args['limit'] if 'limit' in args else 300000
+        limit   = str(limit) 
+        SQL = [
+            """ 
+            SELECT :fields 
+            FROM 
+            """]
+        fields = []
+        prev_table = None
+        for table in tables :
+            name = table['full_name'] #".".join([self.i_dataset,table['name']])
+            alias= table['name']
+            index = tables.index(table)
+            sql_ = """ 
+                (select * from :name limit :limit) as :alias
+            """.replace(":limit",limit)
+            sql_ = sql_.replace(":name",name).replace(":alias",alias)
+            fields += [self.get_field_name(alias,field_name,index) for field_name in table['fields'] if field_name != key or  (field_name==key and  tables.index(table) == 0) ]
+            if tables.index(table) > 0 :
+                join = """
+                    INNER JOIN :sql ON :alias.:field = :prev_alias.:field
+                """.replace(":name",name)
+                join = join.replace(":alias",alias).replace(":field",key).replace(":prev_alias",prev_alias)
+                sql_ = join.replace(":sql",sql_)
+                # sql_ = " ".join([sql_,join])
+            SQL += [sql_]
+            if index == 0:
+                prev_alias = str(alias)
+        
+        return " ".join(SQL).replace(":fields"," , ".join(fields))
+
+class risk :
+    """
+        This class will handle the creation of an SQL query that computes marketer and prosecutor risk (for now)
+    """
+    def __init__(self):
+        pass
+    def get_sql(self,**args) :
+        """
+            This function returns the SQL Query that will compute marketer and prosecutor risk
+            @param key      key fields (patient identifier)
+            @param table    table that is subject of the computation
+        """
+        key     = args['key']
+        table   = args['table']
+        fields  = list(set(table['fields']) - set([key]))
+        #-- We need to select n-fields max 64
+        k = len(fields)        
+        n = np.random.randint(2,24)  #-- how many random fields are we processing
+        ii = np.random.choice(k,n,replace=False)
+        fields = list(np.array(fields)[ii])
+
+        sql = """
+            SELECT COUNT(g_size) as group_count, SUM(g_size) as patient_count, COUNT(g_size)/SUM(g_size) as marketer, 1/ MIN(g_size) as prosecutor
+            FROM (
+                SELECT COUNT(*) as g_size,:key,:fields
+                FROM :full_name
+                GROUP BY :key,:fields
+            )
+        """.replace(":fields", ",".join(fields)).replace(":full_name",table['full_name']).replace(":key",key).replace(":n",str(n))
+        return sql
+    
+
+        
+
+
+if 'action' in SYS_ARGS and  SYS_ARGS['action'] in ['create','compute'] :
+
+    path = SYS_ARGS['path']
+    client = bq.Client.from_service_account_json(path)
+    i_dataset = SYS_ARGS['i_dataset']
+    key = SYS_ARGS['key'] 
+
+    mytools = utils(client = client)
+    tables = mytools.get_tables(dataset=i_dataset,client=client,key=key)
+    # print len(tables)
+    # tables = tables[:6]
+
+    if SYS_ARGS['action'] == 'create' :
+        #usage:
+        #   create --i_dataset <in dataset> --key <patient id> --o_dataset <out dataset> --table <table|file> [--file] --path <bq JSON account file>
+        #
+        create_sql = mytools.get_sql(tables=tables,key=key) #-- The create statement
+        o_dataset = SYS_ARGS['o_dataset']
+        table = SYS_ARGS['table']
+        if 'file' in SYS_ARGS :
+            f = open(table+'.sql','w')
+            f.write(create_sql)
+            f.close()
+        else:
+            job = bq.QueryJobConfig()
+            job.destination = client.dataset(o_dataset).table(table)
+            job.use_query_cache = True
+            job.allow_large_results = True 
+            job.priority = 'BATCH'
+            job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY)
+
+            r = client.query(create_sql,location='US',job_config=job) 
+            
+            print [r.job_id,' ** ',r.state]
+    else:
+        #
+        #
+        tables = [tab for tab in tables if tab['name'] == SYS_ARGS['table'] ]  
+        if tables :            
+            risk = risk()
+            df = pd.DataFrame()
+            for i in range(0,10) :
+                sql = risk.get_sql(key=SYS_ARGS['key'],table=tables[0])
+                df = df.append(pd.read_gbq(query=sql,private_key=path,dialect='standard'))
+                df.to_csv(SYS_ARGS['table']+'.csv')
+                print [i,' ** ',df.shape[0]]
+                time.sleep(2)
+                
+    pass
+else:
+    print 'ERROR'
+    pass
+
+# r = risk(path='/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json', i_dataset='raw',o_dataset='risk_o',o_table='mo')
+# tables = r.get_tables('raw','person_id')
+# sql = r.get_sql(tables=tables[:3],key='person_id')
+# #
+# # let's post this to a designated location
+# #
+# f = open('foo.sql','w')
+# f.write(sql)
+# f.close()
+# r.get_sql(tables=tables,key='person_id')
+# p = r.compute()
+# print p
+# p.to_csv("risk.csv")
+# r.write('foo.sql')