Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import pyro
- import pyro.distributions as dist
- import torch
# =============================================================================
# Configuration
# =============================================================================

# Actors in the simulation; the model iterates them in this order when it
# creates and evolves per-agent sample sites.
AGENT_NAMES = "Trump Khamenei IRGC Israel MAGA Neocons China Russia EU".split()
# Normal prior (mean, std) for every per-agent variable at t = 0. The model
# reuses each std as the step noise when it evolves the variable over time.
AGENT_INIT = {
    var: {'mean': mean, 'std': std}
    for var, mean, std in (
        ('intent', 0.5, 0.15),
        ('capability', 0.7, 0.1),
        ('risk_tolerance', 0.5, 0.13),
        ('ideology', 0.6, 0.15),
    )
}
# For each agent, the (source agent, weight name) pairs whose intents are
# mixed into that agent's effective intent (see agent_weighted_intent).
# Weight names follow "w_<target>_self" for self-influence and
# "w_<target>_<source>" for influence from another agent.
INFLUENCE_NETWORK = {
    target: [(target, f"w_{target.lower()}_self")] + [
        (source, f"w_{target.lower()}_{source.lower()}") for source in sources
    ]
    for target, sources in (
        ('Trump', ('MAGA', 'Neocons')),
        ('Khamenei', ('IRGC',)),
        ('Israel', ()),
        ('IRGC', ()),
        ('MAGA', ()),
        ('Neocons', ()),
        ('China', ()),
        ('Russia', ()),
        ('EU', ()),
    )
}
# Coefficients referenced by name from INFLUENCE_NETWORK. Each agent's
# weights are non-negative and sum to 1.0, so its weighted intent is a
# convex mix of its sources' intents.
INFLUENCE_WEIGHTS = dict(
    w_trump_self=0.6,
    w_trump_maga=0.25,
    w_trump_neocons=0.15,
    w_khamenei_self=0.7,
    w_khamenei_irgc=0.3,
    w_israel_self=1.0,
    w_irgc_self=1.0,
    w_maga_self=1.0,
    w_neocons_self=1.0,
    w_china_self=1.0,
    w_russia_self=1.0,
    w_eu_self=1.0,
)
# Priors for the world state at t = 0: Normal(mean, std) entries, a Beta(a, b)
# for geopolitical tension, Categorical probabilities for regional-conflict
# levels 0/1/2, and a Bernoulli probability for supply-chain disruption.
WORLD_INIT = dict(
    oil_price=dict(mean=85, std=5),
    geopolitical_tension=dict(a=2, b=5),
    global_sentiment=dict(mean=0, std=1),
    regional_conflict=[0.6, 0.3, 0.1],
    supply_chain_disruption=0.2,
    inflation=dict(mean=5, std=1),
    gdp_iran=dict(mean=400, std=80),
    defence_budget_iran=dict(mean=20, std=4),
    currency_iran=dict(mean=500_000, std=50_000),
)
# Logistic coefficients per event. There is no intercept: the i-th weight
# multiplies the i-th feature the model passes for that event, so each list's
# length must match that event's feature list.
EVENT_PARAMS = {
    event: {'weights': weights}
    for event, weights in [
        ('us_bombed_nuclear_sites', [2.2, 1.5, 1.0, 1.0, 1.4]),
        ('israel_bombed_nuclear_sites', [2.3, 1.2, 1.3]),
        ('khamenei_assassinated', [1.5, 1.3, 1.2]),
        ('nuclear_deal_signed', [1.0, 1.0, 1.0, 1.0]),
        ('snapback_activated', [1.3, 0.9, 0.9, 0.6, 0.6]),
        ('iran_covert_disrupt_region', [1.8, 1.2, 1.1]),
        ('iran_covert_terrorist_attack_on_west', [1.6, 1.1, 1.1]),
        ('sanctions_imposed', [1.2, 1.2, 1.0, 0.6]),
        ('oil_supply_disrupted', [1.3, 0.9, 1.1, 0.8]),
        ('major_protest', [1.0, 1.0, 0.7, 0.7]),
    ]
}
def sigmoid_weighted_sum(weights, variables):
    """Return ``sigmoid(sum_i weights[i] * variables[i])`` as a tensor.

    Args:
        weights: Sequence of numeric coefficients.
        variables: Sequence of numbers and/or tensors, same length as
            *weights*.

    Returns:
        A tensor with the logistic squash of the weighted sum. If any
        variable is a tensor that requires grad, the result stays attached
        to the autograd graph, so gradients can flow through it.

    Raises:
        ValueError: if *weights* and *variables* differ in length. A
            mismatch previously went unnoticed because ``zip`` truncates.
    """
    weights = list(weights)
    variables = list(variables)
    if len(weights) != len(variables):
        raise ValueError(
            f"got {len(weights)} weights for {len(variables)} variables"
        )
    z = sum(w * v for w, v in zip(weights, variables))
    # Use an existing tensor as-is: torch.tensor(tensor) would copy it,
    # detach it from the graph (breaking gradients) and emit a warning.
    if not torch.is_tensor(z):
        z = torch.tensor(z, dtype=torch.get_default_dtype())
    return torch.sigmoid(z)
def agent_weighted_intent(agent, agent_vars, t, weights, network=None):
    """Return *agent*'s influence-weighted intent at time step *t*.

    The result is the sum of ``weights[weight_name] * intent(source)`` over
    the ``(source, weight_name)`` pairs listed for *agent* in *network*.
    A source's intent is read from
    ``agent_vars[source][f"intent_{source}_{t}"]``.

    Args:
        agent: Key into *network* whose influence list is summed.
        agent_vars: Mapping ``agent -> {site_name: value}``.
        t: Time step used to build the intent key.
        weights: Mapping from weight name to coefficient.
        network: Mapping ``agent -> [(source_agent, weight_name), ...]``.
            Defaults to the module-level ``INFLUENCE_NETWORK``. Pass a
            different network to explore alternate influence structures.

    Returns:
        ``0.0`` when *agent* has no influences. Otherwise the weighted sum,
        which is a tensor when the intents are tensors.

    Raises:
        KeyError: if the agent, a weight name, or an intent entry is missing.
    """
    if network is None:
        network = INFLUENCE_NETWORK
    # Start from 0.0 so an empty influence list yields a float, as before.
    return sum(
        (
            weights[weight_name] * agent_vars[source][f"intent_{source}_{t}"]
            for source, weight_name in network[agent]
        ),
        0.0,
    )
def _exogenous_shock(name, low=0.0, high=1.0):
    """Sample a ``Uniform(low, high)`` shock as the named Pyro site *name*.

    Drawing shocks through ``pyro.sample`` instead of ``torch.rand`` keeps
    every source of randomness in the Pyro trace. Runs can then be replayed,
    conditioned and used for inference, and the result is a 0-dim tensor
    rather than a shape-(1,) one.
    """
    return pyro.sample(name, dist.Uniform(low, high))


def _sample_event(event, t, features):
    """Sample the Bernoulli site ``f"{event}_{t}"`` and return its probability.

    The probability is ``sigmoid_weighted_sum`` of *features* with the
    coefficients ``EVENT_PARAMS[event]['weights']``. *features* must
    therefore be listed in the same order as those coefficients.
    """
    prob = sigmoid_weighted_sum(EVENT_PARAMS[event]['weights'], features)
    pyro.sample(f"{event}_{t}", dist.Bernoulli(prob))
    return prob


def model(T=8):
    """Generative Pyro model of agent dispositions, world state and events.

    First, samples the initial world state and every agent variable at t = 0.
    Then, for each step ``t`` in ``1..T``, it:

    * evolves each agent variable as ``Normal(0.85 * prev + 0.15 * u, std)``
      with ``u ~ Uniform(0, 1)``;
    * evolves oil price, geopolitical tension, global sentiment and
      inflation with similar shock-driven updates;
    * samples ten Bernoulli event sites whose probabilities are logistic
      functions of influence-weighted intents and world state;
    * samples the next regional-conflict level (0, 1 or 2, where level 2 is
      treated as direct conflict at the following step).

    Every site is a scalar. Names are ``"<variable>_<t>"`` for world state
    and events, and ``"<variable>_<agent>_<t>"`` for agent variables. The
    shocks driving each update are sampled at ``"<site>_shock"``.

    Args:
        T: Number of time steps after the initial state.

    Returns:
        None. Sampled values are only observable through Pyro's trace.
    """
    world = WORLD_INIT

    # --- Initial world state (t = 0) ----------------------------------------
    oil_price = pyro.sample(
        "oil_price_0",
        dist.Normal(world['oil_price']['mean'], world['oil_price']['std']),
    )
    tension = pyro.sample(
        "geopolitical_tension_0",
        dist.Beta(world['geopolitical_tension']['a'],
                  world['geopolitical_tension']['b']),
    )
    global_sentiment = pyro.sample(
        "global_sentiment_0",
        dist.Normal(world['global_sentiment']['mean'],
                    world['global_sentiment']['std']),
    )
    regional_conflict = pyro.sample(
        "regional_conflict_0",
        dist.Categorical(torch.tensor(world['regional_conflict'])),
    )
    pyro.sample(
        "supply_chain_disruption_0",
        dist.Bernoulli(world['supply_chain_disruption']),
    )
    inflation = pyro.sample(
        "inflation_0",
        dist.Normal(world['inflation']['mean'], world['inflation']['std']),
    )
    # Sampled once at t = 0; these do not feed into the dynamics below.
    for key in ("gdp_iran", "defence_budget_iran", "currency_iran"):
        pyro.sample(f"{key}_0",
                    dist.Normal(world[key]['mean'], world[key]['std']))

    # --- Initial agent variables (t = 0) ------------------------------------
    agent_vars = {agent: {} for agent in AGENT_NAMES}
    for agent in AGENT_NAMES:
        for var, prior in AGENT_INIT.items():
            site = f"{var}_{agent}_0"
            agent_vars[agent][site] = pyro.sample(
                site, dist.Normal(prior['mean'], prior['std'])
            )

    for t in range(1, T + 1):
        # --- Agent dynamics: 85% persistence, 15% pull toward a U(0,1) shock.
        for agent in AGENT_NAMES:
            for var, prior in AGENT_INIT.items():
                site = f"{var}_{agent}_{t}"
                prev_val = agent_vars[agent][f"{var}_{agent}_{t - 1}"]
                shock = _exogenous_shock(f"{site}_shock")
                agent_vars[agent][site] = pyro.sample(
                    site,
                    dist.Normal(0.85 * prev_val + 0.15 * shock, prior['std']),
                )

        weighted_intents = {
            agent: agent_weighted_intent(agent, agent_vars, t,
                                         INFLUENCE_WEIGHTS)
            for agent in AGENT_NAMES
        }

        # --- World-state dynamics -------------------------------------------
        oil_price = pyro.sample(
            f"oil_price_{t}",
            dist.Normal(
                oil_price + _exogenous_shock(f"oil_price_{t}_shock", -1.0, 1.0),
                world['oil_price']['std'],
            ),
        )
        tension = pyro.sample(
            f"geopolitical_tension_{t}",
            dist.Normal(
                0.85 * tension
                + 0.15 * _exogenous_shock(f"geopolitical_tension_{t}_shock"),
                0.06,
            ),
        )
        sentiment_shock = pyro.sample(f"global_sentiment_{t}_shock",
                                      dist.Normal(0.0, 1.0))
        global_sentiment = pyro.sample(
            f"global_sentiment_{t}",
            dist.Normal(0.8 * global_sentiment + 0.2 * sentiment_shock, 0.1),
        )
        inflation = pyro.sample(
            f"inflation_{t}",
            dist.Normal(
                0.9 * inflation + 0.1 * _exogenous_shock(f"inflation_{t}_shock"),
                world['inflation']['std'],
            ),
        )

        # --- Events -----------------------------------------------------------
        # 1.0 when the previous step's regional-conflict level is 2. Kept as a
        # tensor; a Python `if` on a tensor breaks as soon as it is not 0-dim.
        direct_conflict = (regional_conflict == 2).float()
        # NOTE(review): enrichment status is not modelled yet. Both flags are
        # fixed at 0, so their coefficients currently have no effect.
        enrichment_ended = torch.tensor(0.0)
        enrichment_ongoing = torch.tensor(0.0)

        trump = agent_vars['Trump']
        _sample_event("us_bombed_nuclear_sites", t, [
            weighted_intents['Trump'],
            trump[f"capability_Trump_{t}"],
            trump[f"risk_tolerance_Trump_{t}"],
            tension,
            trump[f"ideology_Trump_{t}"],
        ])
        _sample_event("israel_bombed_nuclear_sites", t, [
            weighted_intents['Israel'],
            agent_vars['Israel'][f"capability_Israel_{t}"],
            tension,
        ])
        _sample_event("khamenei_assassinated", t, [
            weighted_intents['Israel'],
            tension,
            direct_conflict,
        ])
        _sample_event("nuclear_deal_signed", t, [
            weighted_intents['Khamenei'],
            weighted_intents['Trump'],
            global_sentiment,
            enrichment_ended,
        ])
        _sample_event("snapback_activated", t, [
            weighted_intents['Trump'],
            weighted_intents['Neocons'],
            tension,
            enrichment_ongoing,
            weighted_intents['EU'],
        ])
        _sample_event("iran_covert_disrupt_region", t, [
            weighted_intents['IRGC'],
            tension,
            direct_conflict,
        ])
        _sample_event("iran_covert_terrorist_attack_on_west", t, [
            weighted_intents['IRGC'],
            tension,
            weighted_intents['Trump'],
        ])
        prob_sanctions = _sample_event("sanctions_imposed", t, [
            weighted_intents['Trump'],
            weighted_intents['Neocons'],
            tension,
            weighted_intents['EU'],
        ])
        _sample_event("oil_supply_disrupted", t, [
            weighted_intents['IRGC'],
            weighted_intents['Trump'],
            direct_conflict,
            weighted_intents['Russia'],
        ])
        prob_protest = _sample_event("major_protest", t, [
            weighted_intents['Khamenei'],
            weighted_intents['IRGC'],
            inflation / 10,
            prob_sanctions,
        ])

        # --- Next regional-conflict level -------------------------------------
        # Every term is a 0-dim tensor, so the stack has shape (3,) and
        # Categorical sees 3 categories on its last dim. With the old
        # shape-(1,) terms it was (3, 1): 3 batch entries of 1 category each.
        logits_regional = torch.stack([
            1.0 - tension,
            weighted_intents['IRGC'] + weighted_intents['Israel'] + tension,
            1.1 * tension + 0.7 * prob_protest,
        ])
        # Categorical(logits=x) is the same distribution as probs=softmax(x).
        regional_conflict = pyro.sample(
            f"regional_conflict_{t}", dist.Categorical(logits=logits_regional)
        )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement