# import_jd_provider_tmp.py (4.0 KB)
import sys
from pathlib import Path

import pandas as pd

from commons.conn_mysql import MySQLPoolOnline
  4. def get_conn():
  5. """
  6. 数据库连接配置(你自己补)。
  7. """
  8. return MySQLPoolOnline()
  9. def read_source_file(file_path: Path, sheet_name=0) -> pd.DataFrame:
  10. suffix = file_path.suffix.lower()
  11. if suffix in {".xlsx", ".xls"}:
  12. return pd.read_excel(file_path, sheet_name=sheet_name, dtype=object)
  13. if suffix == ".csv":
  14. return pd.read_csv(file_path, dtype=object, encoding="utf-8-sig")
  15. raise ValueError(f"不支持的文件类型: {suffix}")
  16. def get_table_columns(conn, table_name: str):
  17. sql = """
  18. SELECT COLUMN_NAME
  19. FROM information_schema.COLUMNS
  20. WHERE TABLE_SCHEMA = DATABASE()
  21. AND TABLE_NAME = %s
  22. ORDER BY ORDINAL_POSITION
  23. """
  24. # 兼容 MySQLPoolOnline(select_data)和原生 pymysql 连接(cursor)
  25. if hasattr(conn, "select_data"):
  26. rows = conn.select_data(sql, (table_name,))
  27. return [r.get("COLUMN_NAME") for r in rows]
  28. with conn.cursor() as cur:
  29. cur.execute(sql, (table_name,))
  30. rows = cur.fetchall()
  31. return [r[0] for r in rows]
  32. def clean_df(df: pd.DataFrame) -> pd.DataFrame:
  33. # 去掉空列名和列名前后空格
  34. df = df.rename(columns=lambda x: str(x).strip() if x is not None else "")
  35. df = df[[c for c in df.columns if c]]
  36. # 把 NaN 转为 None,便于写库
  37. return df.where(pd.notna(df), None)
  38. def import_to_jd_provider_tmp(file_path: str, table_name="jd_provider_tmp", sheet_name=0, batch_size=500):
  39. file = Path(file_path)
  40. if not file.exists():
  41. raise FileNotFoundError(f"文件不存在: {file}")
  42. df = clean_df(read_source_file(file, sheet_name=sheet_name))
  43. if df.empty:
  44. print("源文件无数据,跳过导入。")
  45. return
  46. conn = get_conn()
  47. is_pool_conn = hasattr(conn, "select_data") and hasattr(conn, "execute_many")
  48. try:
  49. table_columns = get_table_columns(conn, table_name)
  50. if not table_columns:
  51. raise RuntimeError(f"未找到表或无字段: {table_name}")
  52. # 仅导入“表头和数据库字段同名”的列
  53. import_columns = [c for c in df.columns if c in table_columns]
  54. if not import_columns:
  55. raise RuntimeError("文件表头与数据库字段无交集,请检查表头是否与表字段同名。")
  56. missing_columns = [c for c in df.columns if c not in table_columns]
  57. if missing_columns:
  58. print(f"以下表头不在数据库表中,已自动忽略: {missing_columns}")
  59. insert_sql = (
  60. f"INSERT INTO `{table_name}` ({', '.join([f'`{c}`' for c in import_columns])}) "
  61. f"VALUES ({', '.join(['%s'] * len(import_columns))})"
  62. )
  63. values = [tuple(row[c] for c in import_columns) for _, row in df.iterrows()]
  64. total = len(values)
  65. inserted = 0
  66. if is_pool_conn:
  67. for i in range(0, total, batch_size):
  68. batch = values[i:i + batch_size]
  69. conn.execute_many(insert_sql, batch)
  70. inserted += len(batch)
  71. print(f"已导入: {inserted}/{total}")
  72. else:
  73. with conn.cursor() as cur:
  74. for i in range(0, total, batch_size):
  75. batch = values[i:i + batch_size]
  76. cur.executemany(insert_sql, batch)
  77. inserted += len(batch)
  78. print(f"已导入: {inserted}/{total}")
  79. conn.commit()
  80. print(f"导入完成,表 `{table_name}` 共导入 {inserted} 条。")
  81. except Exception:
  82. if not is_pool_conn:
  83. conn.rollback()
  84. raise
  85. finally:
  86. # 连接池对象无需 close;原生连接需要 close
  87. if hasattr(conn, "close") and not is_pool_conn:
  88. conn.close()
  89. if __name__ == "__main__":
  90. fixed_file = Path.cwd() / "11111.xlsx"
  91. import_to_jd_provider_tmp(
  92. file_path=str(fixed_file),
  93. table_name="jd_provider_tmp",
  94. sheet_name=0,
  95. batch_size=500,
  96. )