diff --git a/csvkit/sql.py b/csvkit/sql.py index 0e809b03bbd7064645249b1272df2d34731bed2f..1aba5c45e1db567197262526c4fb468c28c0ed70 100644 --- a/csvkit/sql.py +++ b/csvkit/sql.py @@ -3,6 +3,8 @@ import datetime import six +import normality +from unidecode import unidecode from sqlalchemy import Column, MetaData, Table, create_engine from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, Integer, String, Time @@ -27,7 +29,12 @@ NULL_COLUMN_MAX_LENGTH = 32 SQL_INTEGER_MAX = 2147483647 SQL_INTEGER_MIN = -2147483647 -def make_column(column, no_constraints=False): + +def normalize_name(name): + return normality.slugify(unidecode(name), sep='_') + + +def make_column(column, no_constraints=False, normalize_columns=False): """ Creates a sqlalchemy column from a csvkit Column. """ @@ -66,7 +73,11 @@ def make_column(column, no_constraints=False): sql_column_kwargs['nullable'] = column.has_nulls() - return Column(column.name, sql_column_type(**sql_type_kwargs), **sql_column_kwargs) + column_name = column.name + if normalize_columns: + column_name = normalize_name(column.name) + + return Column(column_name, sql_column_type(**sql_type_kwargs), **sql_column_kwargs) def get_connection(connection_string): engine = create_engine(connection_string) @@ -74,7 +85,7 @@ def get_connection(connection_string): return engine, metadata -def make_table(csv_table, name='table_name', no_constraints=False, db_schema=None, metadata=None): +def make_table(csv_table, name='table_name', no_constraints=False, db_schema=None, normalize_columns=False, metadata=None): """ Creates a sqlalchemy table from a csvkit Table. """ @@ -84,7 +95,7 @@ def make_table(csv_table, name='table_name', no_constraints=False, db_schema=Non sql_table = Table(csv_table.name, metadata, schema=db_schema) for column in csv_table: - sql_table.append_column(make_column(column, no_constraints)) + sql_table.append_column(make_column(column, no_constraints, normalize_columns)) return sql_table diff --git a/csvkit/utilities/csvsql.py b/csvkit/utilities/csvsql.py index 4a3dd74748acda0d0d5a8d362563d0fb4fec63f7..25121873987c1d7504ab975a53f247bf5e5aaf1c 100644 --- a/csvkit/utilities/csvsql.py +++ b/csvkit/utilities/csvsql.py @@ -29,6 +29,8 @@ class CSVSQL(CSVKitUtility): help='In addition to creating the table, also insert the data into the table. Only valid when --db is specified.') self.argparser.add_argument('--tables', dest='table_names', help='Specify one or more names for the tables to be created. If omitted, the filename (minus extension) or "stdin" will be used.') + self.argparser.add_argument('-n', '--normalize-columns', dest='normalize_columns', action='store_true', + help='Normalize the headers before generating column names.') self.argparser.add_argument('--no-constraints', dest='no_constraints', action='store_true', help='Generate a schema without length limits or null checks. Useful when sampling big tables.') self.argparser.add_argument('--no-create', dest='no_create', action='store_true', @@ -115,6 +117,7 @@ class CSVSQL(CSVKitUtility): table_name, self.args.no_constraints, self.args.db_schema, + self.args.normalize_columns, metadata ) @@ -125,7 +128,7 @@ class CSVSQL(CSVKitUtility): # Insert data if do_insert and csv_table.count_rows() > 0: insert = sql_table.insert() - headers = csv_table.headers() + headers = [sql.normalize_name(h) for h in csv_table.headers()] conn.execute(insert, [dict(zip(headers, row)) for row in csv_table.to_rows()]) # Output SQL statements diff --git a/setup.py b/setup.py index e27e33d806c24a65aadf64af31b7d08e91670ce3..f9b34da5d5e7da05663826cfbff615eda1bd146a 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,9 @@ install_requires = [ 'openpyxl==2.2.6', 'six>=1.6.1', 'python-dateutil==2.2', - 'dbf>=0.96.005' + 'dbf>=0.96.005', + 'Unidecode>=0.04.19', + 'normality>=0.2.4' ] if sys.version_info < (2, 7): diff --git a/tests/test_sql.py b/tests/test_sql.py index f7abcf0e05b5605b5e95eb14adeda0697e2ca39a..917c3026508746a2811aa1ece802b1c4bc865136 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# coding: utf-8 try: import unittest2 as unittest @@ -101,6 +102,19 @@ u"""CREATE TABLE test_table ( \tempty_column VARCHAR );""") + def test_make_create_table_statement_normalize_names(self): + csv_table = table.Table([ + table.Column(0, u'äää H!HU', [u'Chicago Reader', u'Chicago Sun-Times', u'Chicago Tribune', u'Row with blanks'])], + name='test_table') + sql_table = sql.make_table(csv_table, 'csvsql', True, None, True) + statement = sql.make_create_table_statement(sql_table) + + self.assertEqual(statement, +u"""CREATE TABLE test_table ( +\taaa_h_hu VARCHAR +);""") + + def test_make_create_table_statement_with_schema(self): sql_table = sql.make_table(self.csv_table, 'csvsql', db_schema='test_schema') statement = sql.make_create_table_statement(sql_table)