diff --git a/examples/earth/fengwu/conf/config.yaml b/examples/earth/fengwu/conf/config.yaml index 259a79f105233d1341b222d4dce77164a4216891..c1260f235a792e29e5a9beea4f6e24dd545faef5 100755 --- a/examples/earth/fengwu/conf/config.yaml +++ b/examples/earth/fengwu/conf/config.yaml @@ -20,8 +20,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" train_time: [1951, 1952] diff --git a/examples/earth/fengwu/fake_data.py b/examples/earth/fengwu/fake_data.py index 00449ec9ea99771ebf96bf35b97e7adf6fb69439..2923e980e6efcaa25aa45d8d0a80538ffb99a7ed 100644 --- a/examples/earth/fengwu/fake_data.py +++ b/examples/earth/fengwu/fake_data.py @@ -23,8 +23,8 @@ def generate_fake_h5(data_dir, var_names, years, dims): path = os.path.join(data_dir, "data", f"{year}.h5") with h5py.File(path, "w") as f: ds = f.create_dataset( - "fields", - shape=(T, C, H, W), + "fields", + shape=(T, C, H, W), dtype="float32", chunks=(1, C, H, W), fillvalue=0.0, @@ -49,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -73,17 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") - arr = np.random.randn(721, 1440).astype(np.float32) # 保存数据 - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") - + ds.to_netcdf(f"{data_dir}/{name}.nc") + arr = np.random.randn(721, 1440).astype(np.float32) + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") + + if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -91,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") + generate_stats(stats_dir, len(atm_vars)) + + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') - print("\n✅ Fake datasets generated.") diff --git a/examples/earth/fourcastnet/conf/config.yaml b/examples/earth/fourcastnet/conf/config.yaml index ca8f9c49d6a44db13471a6025a99f3bb81b006f1..d754a09a59efb31cb509b8fa94dfd8ba0644c83f 100755 --- a/examples/earth/fourcastnet/conf/config.yaml +++ b/examples/earth/fourcastnet/conf/config.yaml @@ -16,8 +16,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" train_time: [1951, 1952] diff --git a/examples/earth/fourcastnet/fake_data.py b/examples/earth/fourcastnet/fake_data.py index c9eb954b1725dda489c2d5efcecbaca445cd49c4..2923e980e6efcaa25aa45d8d0a80538ffb99a7ed 100644 --- a/examples/earth/fourcastnet/fake_data.py +++ b/examples/earth/fourcastnet/fake_data.py @@ -49,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -73,18 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") + ds.to_netcdf(f"{data_dir}/{name}.nc") arr = np.random.randn(721, 1440).astype(np.float32) - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -92,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") + generate_stats(stats_dir, len(atm_vars)) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/fuxi/conf/config.yaml b/examples/earth/fuxi/conf/config.yaml index 612f1bf0dfc4fa997976ab6650c8eb00ea8b84be..38aa0f938eb579d98cf445cdc0a6773949ee954f 100755 --- a/examples/earth/fuxi/conf/config.yaml +++ b/examples/earth/fuxi/conf/config.yaml @@ -27,8 +27,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" train_time: [1951, 1952] val_time: [1953] diff --git a/examples/earth/fuxi/fake_data.py b/examples/earth/fuxi/fake_data.py index eb0770aa9711ca44547b9324626a01f91d8c6f58..b4bf55e65c93eb96c5129746b4b6fa5a1a54bb70 100644 --- a/examples/earth/fuxi/fake_data.py +++ b/examples/earth/fuxi/fake_data.py @@ -137,13 +137,15 @@ if __name__ == "__main__": # 主 ERA5 数据 generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, n_vars) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") + generate_stats(stats_dir, n_vars) # 各阶段中间模型输出(用于 short/medium/long 微调的输入) for stage in ['base', 'short', 'medium', 'long']: generate_fake_npy(f'./result/{stage}', n_vars, years, DATASET_DIMS) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/graphcast/compute_time_diff_std.py b/examples/earth/graphcast/compute_time_diff_std.py index b7f52ef1cf84c8d39e0a2e530a63b7e939c087c1..5328cc8acf05214b37817634396777c6d28dcdc3 100755 --- a/examples/earth/graphcast/compute_time_diff_std.py +++ b/examples/earth/graphcast/compute_time_diff_std.py @@ -26,7 +26,6 @@ def main(): area = area.unsqueeze(1) mean, mean_sqr = 0, 0 - k = 0 for data in tqdm(train_dataloader): invar = data[0] # [b, N, h, w] outvar = data[1] # [b, N, h, w] @@ -35,7 +34,6 @@ def main(): weighted_diff_sqr = torch.square(weighted_diff) mean += torch.mean(weighted_diff, dim=(2, 3)) / len(train_dataloader) mean_sqr += torch.mean(weighted_diff_sqr, dim=(2, 3)) / len(train_dataloader) - k += 1 variance = mean_sqr - mean**2 # [1,num_channel, 1,1] std = torch.sqrt(variance) diff --git a/examples/earth/graphcast/conf/config.yaml b/examples/earth/graphcast/conf/config.yaml index 4d9fd44453bdf754641f790b8d717925f337aa7a..1efe2e35fc95ed4b5620b6ad9c70541f51cb91d1 100755 --- a/examples/earth/graphcast/conf/config.yaml +++ b/examples/earth/graphcast/conf/config.yaml @@ -52,8 +52,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" dataset_metadata_path: './data.json' # Path to the dataset metadata, containing channel names. time_diff_std_path: './time_diff_std.npy' # Path to the .npy file with standard deviation of normalized per-variable per-pressure level time differences. diff --git a/examples/earth/graphcast/fake_data.py b/examples/earth/graphcast/fake_data.py index 1f0737f7388fdd699e44d3fc82539afe27f82270..2923e980e6efcaa25aa45d8d0a80538ffb99a7ed 100644 --- a/examples/earth/graphcast/fake_data.py +++ b/examples/earth/graphcast/fake_data.py @@ -6,46 +6,32 @@ from onescience.utils.YParams import YParams # 各数据集固定的空间和时间维度 -DATASET_DIMS = {"T": 50, "H": 721, "W": 1440, "time_step": 6} -PATCH_HW = 32 -PATCHES_PER_STEP = 2 -RNG_SEED = 42 +DATASET_DIMS = {"T": 10, "H": 721, "W": 1440, "time_step": 6} def generate_fake_h5(data_dir, var_names, years, dims): """ - 为每个年份生成一个稀疏随机 h5 文件。 - 仅写入少量随机 patch,其余 chunk 仍保持 fill_value=0, - 这样既能让 time-diff std 非零,也能保持文件很小。 + 为每个年份生成一个空 h5 文件。 + 利用 HDF5 chunked 压缩数据集未写入 chunk 即返回 fill_value=0 的特性, + 文件实际只含元数据,极小,但 shape 与真实数据完全一致。 """ os.makedirs(os.path.join(data_dir, "data"), exist_ok=True) T, C = dims["T"], len(var_names) H, W = dims["H"], dims["W"] - patch_h = min(PATCH_HW, H) - patch_w = min(PATCH_HW, W) - for year_idx, year in enumerate(years): + for year in years: path = os.path.join(data_dir, "data", f"{year}.h5") - rng = np.random.default_rng(RNG_SEED + year_idx) with h5py.File(path, "w") as f: ds = f.create_dataset( "fields", shape=(T, C, H, W), dtype="float32", - chunks=(1, C, patch_h, patch_w), + chunks=(1, C, H, W), fillvalue=0.0, ) ds.attrs["variables"] = var_names ds.attrs["time_step"] = dims["time_step"] - for t in range(T): - for _ in range(PATCHES_PER_STEP): - top = int(rng.integers(0, H - patch_h + 1)) - left = int(rng.integers(0, W - patch_w + 1)) - patch = rng.standard_normal((C, patch_h, patch_w)).astype(np.float32) - patch += np.float32(0.1 * t) - ds[t, :, top:top + patch_h, left:left + patch_w] = patch - size_kb = os.path.getsize(path) / 1024 print(f" {year}.h5 shape=({T},{C},{H},{W}) " f"logical={T*C*H*W*4/1024**3:.1f}GB actual={size_kb:.1f}KB") @@ -63,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -87,18 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") + ds.to_netcdf(f"{data_dir}/{name}.nc") arr = np.random.randn(721, 1440).astype(np.float32) - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -106,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") + generate_stats(stats_dir, len(atm_vars)) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/pangu_weather/conf/config.yaml b/examples/earth/pangu_weather/conf/config.yaml index de1c893ad47d8c227418fd57984f539e5dec781a..8339c7965ea2f8544980b5d05b5aad603e26ae88 100755 --- a/examples/earth/pangu_weather/conf/config.yaml +++ b/examples/earth/pangu_weather/conf/config.yaml @@ -16,8 +16,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' - static_dir: './data/static/' # data_dir: './data/' train_time: [1951, 1952] diff --git a/examples/earth/pangu_weather/fake_data.py b/examples/earth/pangu_weather/fake_data.py index c9eb954b1725dda489c2d5efcecbaca445cd49c4..2923e980e6efcaa25aa45d8d0a80538ffb99a7ed 100644 --- a/examples/earth/pangu_weather/fake_data.py +++ b/examples/earth/pangu_weather/fake_data.py @@ -49,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -73,18 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") + ds.to_netcdf(f"{data_dir}/{name}.nc") arr = np.random.randn(721, 1440).astype(np.float32) - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -92,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") + generate_stats(stats_dir, len(atm_vars)) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/pangu_weather/train.py b/examples/earth/pangu_weather/train.py index 6badf83d024d3803b9a89f4cbd43c714104bf331..db4438635b6e6e6004cb6e64f74c4b4e578b2253 100755 --- a/examples/earth/pangu_weather/train.py +++ b/examples/earth/pangu_weather/train.py @@ -57,9 +57,11 @@ def main(): surface_weights = torch.as_tensor(cfg_data.dataset.weights[:4], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) pressure_weights = torch.as_tensor(cfg_data.dataset.weights[4:], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) - land_mask = torch.from_numpy(np.load(os.path.join(cfg_data.dataset.static_dir, "land_mask.npy")).astype(np.float32)) - soil_type = torch.from_numpy(np.load(os.path.join(cfg_data.dataset.static_dir, "soil_type.npy")).astype(np.float32)) - topography = torch.from_numpy(np.load(os.path.join(cfg_data.dataset.static_dir, "topography.npy")).astype(np.float32)) + static_dir = os.path.join(cfg_data.dataset.data_dir, "static") + + land_mask = torch.from_numpy(np.load(os.path.join(static_dir, "land_mask.npy")).astype(np.float32)) + soil_type = torch.from_numpy(np.load(os.path.join(static_dir, "soil_type.npy")).astype(np.float32)) + topography = torch.from_numpy(np.load(os.path.join(static_dir, "topography.npy")).astype(np.float32)) topography = (topography - topography.mean()) / (topography.std(unbiased=False) + 1e-6) surface_mask = torch.stack([land_mask, soil_type, topography], dim=0).to(local_rank) surface_mask = surface_mask.unsqueeze(0).repeat(cfg_data.dataloader.batch_size, 1, 1, 1)