Update discord_lnbits_bot.py
This commit is contained in:
parent
22812b628d
commit
a07cdecc32
@ -24,19 +24,12 @@ from sqlalchemy.orm import sessionmaker, relationship
|
|||||||
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
|
||||||
# ─── Logging ──────────────────────────────────────────────────────────────────
|
# ─── Logging ──────────────────────────────────────────────────────────────
|
||||||
logging.basicConfig(
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s | %(message)s")
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s %(levelname)s %(name)s | %(message)s"
|
|
||||||
)
|
|
||||||
logger = logging.getLogger("discord_lnbits_bot")
|
logger = logging.getLogger("discord_lnbits_bot")
|
||||||
|
|
||||||
# ─── DB Setup & Config Seeding ─────────────────────────────────────────────────
|
# ─── Database & Config Seeding ────────────────────────────────────────────
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
DATABASE_URL = os.environ["DATABASE_URL"]
|
||||||
if not DATABASE_URL:
|
|
||||||
logger.critical("Missing DATABASE_URL in .env")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
engine = create_engine(DATABASE_URL, future=True)
|
engine = create_engine(DATABASE_URL, future=True)
|
||||||
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
|
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
@ -55,10 +48,7 @@ class Plan(Base):
|
|||||||
channel_id = Column(String(64), nullable=False)
|
channel_id = Column(String(64), nullable=False)
|
||||||
price_sats = Column(Integer, nullable=False)
|
price_sats = Column(Integer, nullable=False)
|
||||||
invoice_message = Column(Text, nullable=False)
|
invoice_message = Column(Text, nullable=False)
|
||||||
expiry_type = Column(
|
expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False)
|
||||||
Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"),
|
|
||||||
nullable=False
|
|
||||||
)
|
|
||||||
duration_days = Column(Integer, default=30)
|
duration_days = Column(Integer, default=30)
|
||||||
subs = relationship("Subscription", back_populates="plan")
|
subs = relationship("Subscription", back_populates="plan")
|
||||||
|
|
||||||
@ -69,10 +59,7 @@ class Subscription(Base):
|
|||||||
user_id = Column(String(64), nullable=False)
|
user_id = Column(String(64), nullable=False)
|
||||||
invoice_hash = Column(String(256), unique=True, nullable=False)
|
invoice_hash = Column(String(256), unique=True, nullable=False)
|
||||||
invoice_request = Column(Text, nullable=False)
|
invoice_request = Column(Text, nullable=False)
|
||||||
status = Column(
|
status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending")
|
||||||
Enum("pending","active","expired","cancelled", name="sub_status"),
|
|
||||||
nullable=False, default="pending"
|
|
||||||
)
|
|
||||||
created_at = Column(DateTime, default=datetime.utcnow)
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
paid_at = Column(DateTime)
|
paid_at = Column(DateTime)
|
||||||
expires_at = Column(DateTime)
|
expires_at = Column(DateTime)
|
||||||
@ -82,10 +69,10 @@ class Subscription(Base):
|
|||||||
raw_payment_json = Column(JSONB)
|
raw_payment_json = Column(JSONB)
|
||||||
plan = relationship("Plan", back_populates="subs")
|
plan = relationship("Plan", back_populates="subs")
|
||||||
|
|
||||||
# create tables & seed Config keys
|
|
||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
logger.info("✅ Database schema ready.")
|
logger.info("✅ Database schema ready.")
|
||||||
|
|
||||||
|
# Seed Config keys (if missing)
|
||||||
session = SessionLocal()
|
session = SessionLocal()
|
||||||
for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"):
|
for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"):
|
||||||
if not session.get(Config, key):
|
if not session.get(Config, key):
|
||||||
@ -93,30 +80,25 @@ for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"):
|
|||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
# ─── Config helper ────────────────────────────────────────────────────────────
|
# ─── Config helper ─────────────────────────────────────────────────────────
|
||||||
def get_cfg(key: str) -> str:
|
def get_cfg(k: str) -> str:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
row = db.get(Config, key)
|
row = db.get(Config, k)
|
||||||
db.close()
|
db.close()
|
||||||
return row.value if row else ""
|
return row.value if row else ""
|
||||||
|
|
||||||
# Read core settings from DB
|
|
||||||
DISCORD_TOKEN = get_cfg("discord_token")
|
DISCORD_TOKEN = get_cfg("discord_token")
|
||||||
GUILD_ID = int(get_cfg("guild_id") or 0)
|
GUILD_ID = int(get_cfg("guild_id") or 0)
|
||||||
LNBITS_URL = get_cfg("lnbits_url")
|
LNBITS_URL = get_cfg("lnbits_url")
|
||||||
LNBITS_API_KEY = get_cfg("lnbits_api_key")
|
LNBITS_API_KEY = get_cfg("lnbits_api_key")
|
||||||
|
|
||||||
for name, val in (
|
for name, val in (("discord_token", DISCORD_TOKEN), ("guild_id", GUILD_ID),
|
||||||
("discord_token", DISCORD_TOKEN),
|
("lnbits_url", LNBITS_URL), ("lnbits_api_key", LNBITS_API_KEY)):
|
||||||
("guild_id", GUILD_ID),
|
|
||||||
("lnbits_url", LNBITS_URL),
|
|
||||||
("lnbits_api_key", LNBITS_API_KEY),
|
|
||||||
):
|
|
||||||
if not val:
|
if not val:
|
||||||
logger.critical(f"Config '{name}' is empty in DB. Fill it via the UI first.")
|
logger.critical(f"Config '{name}' is empty in DB. Fill via UI first.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# Build WebSocket URL
|
# Build WS URL
|
||||||
if LNBITS_URL.startswith("https://"):
|
if LNBITS_URL.startswith("https://"):
|
||||||
ws_base = LNBITS_URL.replace("https://", "wss://", 1)
|
ws_base = LNBITS_URL.replace("https://", "wss://", 1)
|
||||||
elif LNBITS_URL.startswith("http://"):
|
elif LNBITS_URL.startswith("http://"):
|
||||||
@ -126,121 +108,76 @@ else:
|
|||||||
exit(1)
|
exit(1)
|
||||||
LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}"
|
LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}"
|
||||||
|
|
||||||
# ─── Discord Bot Setup ────────────────────────────────────────────────────────
|
# ─── Discord Bot ──────────────────────────────────────────────────────────
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default(); intents.members = True
|
||||||
intents.members = True
|
|
||||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||||
|
|
||||||
# ─── Expiry computation ────────────────────────────────────────────────────────
|
|
||||||
def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
|
def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
|
||||||
if plan.expiry_type == "calendar_month":
|
if plan.expiry_type == "calendar_month":
|
||||||
first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||||
return first + relativedelta(months=1)
|
return first + relativedelta(months=1)
|
||||||
if plan.expiry_type == "rolling_days":
|
if plan.expiry_type == "rolling_days":
|
||||||
return paid_at + relativedelta(days=plan.duration_days)
|
return paid_at + relativedelta(days=plan.duration_days)
|
||||||
return paid_at # fixed_date handled via UI
|
return paid_at
|
||||||
|
|
||||||
# ─── Handle plan purchase ─────────────────────────────────────────────────────
|
async def handle_plan_purchase(interaction, plan: Plan):
|
||||||
async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
|
|
||||||
invoice_ch = bot.get_channel(int(plan.channel_id))
|
invoice_ch = bot.get_channel(int(plan.channel_id))
|
||||||
if not invoice_ch:
|
if not invoice_ch:
|
||||||
return await interaction.response.send_message(
|
return await interaction.response.send_message("❌ Invoice channel not found.", ephemeral=True)
|
||||||
"❌ Invoice channel not found.", ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
invoice_data = {
|
|
||||||
"out": False,
|
|
||||||
"amount": plan.price_sats,
|
|
||||||
"memo": plan.invoice_message
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
"X-Api-Key": LNBITS_API_KEY,
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
invoice_data = {"out": False, "amount": plan.price_sats, "memo": plan.invoice_message}
|
||||||
|
headers = {"X-Api-Key": LNBITS_API_KEY, "Content-Type": "application/json"}
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
def create_invoice():
|
def create_invoice():
|
||||||
r = requests.post(
|
r = requests.post(f"{LNBITS_URL}/api/v1/payments", json=invoice_data, headers=headers, timeout=10)
|
||||||
f"{LNBITS_URL}/api/v1/payments",
|
|
||||||
json=invoice_data, headers=headers, timeout=10
|
|
||||||
)
|
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
inv = await loop.run_in_executor(None, create_invoice)
|
inv = await loop.run_in_executor(None, create_invoice)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Invoice creation failed")
|
logger.exception("Invoice creation failed")
|
||||||
return await interaction.response.send_message(
|
return await interaction.response.send_message("❌ Could not generate invoice.", ephemeral=True)
|
||||||
"❌ Failed to generate invoice.", ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
pay_req = inv["bolt11"]
|
|
||||||
pay_hash = inv["payment_hash"]
|
|
||||||
|
|
||||||
|
pay_req, pay_hash = inv["bolt11"], inv["payment_hash"]
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
sub = Subscription(
|
sub = Subscription(plan_id=plan.id, user_id=str(interaction.user.id),
|
||||||
plan_id = plan.id,
|
invoice_hash=pay_hash, invoice_request=pay_req,
|
||||||
user_id = str(interaction.user.id),
|
raw_invoice_json=inv)
|
||||||
invoice_hash = pay_hash,
|
|
||||||
invoice_request = pay_req,
|
|
||||||
raw_invoice_json= inv
|
|
||||||
)
|
|
||||||
db.add(sub); db.commit(); db.close()
|
db.add(sub); db.commit(); db.close()
|
||||||
logger.info(f"Stored pending {pay_hash} for user {interaction.user.id}")
|
|
||||||
|
|
||||||
def make_qr_buf():
|
def make_qr_buf():
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO(); qrcode.make(pay_req).save(buf, format="PNG"); buf.seek(0); return buf
|
||||||
qrcode.make(pay_req).save(buf, format="PNG")
|
|
||||||
buf.seek(0)
|
|
||||||
return buf
|
|
||||||
|
|
||||||
qr_buf = await loop.run_in_executor(None, make_qr_buf)
|
qr_buf = await loop.run_in_executor(None, make_qr_buf)
|
||||||
qr_file = File(qr_buf, filename="invoice.png")
|
qr_file= File(qr_buf, "invoice.png")
|
||||||
|
|
||||||
embed = Embed(
|
embed = Embed(title=f"⚡ Pay {plan.price_sats} sats for {plan.name} ⚡",
|
||||||
title=f"⚡ Pay {plan.price_sats} sats for {plan.name} ⚡",
|
description=plan.invoice_message, color=discord.Color.gold())
|
||||||
description=plan.invoice_message,
|
|
||||||
color=discord.Color.gold()
|
|
||||||
)
|
|
||||||
embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False)
|
embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False)
|
||||||
embed.set_image(url="attachment://invoice.png")
|
embed.set_image(url="attachment://invoice.png")
|
||||||
embed.set_footer(text=f"Hash: {pay_hash}")
|
embed.set_footer(text=f"Hash: {pay_hash}")
|
||||||
|
|
||||||
await invoice_ch.send(content=interaction.user.mention, embed=embed, file=qr_file)
|
await invoice_ch.send(content=interaction.user.mention, embed=embed, file=qr_file)
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True)
|
||||||
f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# ─── WS listener ───────────────────────────────────────────────────────────────
|
|
||||||
async def lnbits_ws_listener():
|
async def lnbits_ws_listener():
|
||||||
await bot.wait_until_ready()
|
await bot.wait_until_ready()
|
||||||
logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}")
|
logger.info(f"Listening on WS {LNBITS_WS_URL}")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
async with websockets.connect(LNBITS_WS_URL) as ws:
|
async with websockets.connect(LNBITS_WS_URL) as ws:
|
||||||
async for msg in ws:
|
async for msg in ws:
|
||||||
data = json.loads(msg)
|
data = json.loads(msg); pay = data.get("payment", {})
|
||||||
pay = data.get("payment", {})
|
|
||||||
hsh = pay.get("checking_id") or pay.get("payment_hash")
|
hsh = pay.get("checking_id") or pay.get("payment_hash")
|
||||||
if not (hsh and pay.get("status")=="success"):
|
if not (hsh and pay.get("status")=="success"): continue
|
||||||
continue
|
|
||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
sub = db.query(Subscription)\
|
sub= db.query(Subscription).filter_by(invoice_hash=hsh, status="pending").first()
|
||||||
.filter_by(invoice_hash=hsh, status="pending")\
|
if not sub: db.close(); continue
|
||||||
.first()
|
|
||||||
if not sub:
|
|
||||||
db.close(); continue
|
|
||||||
|
|
||||||
paid_at = datetime.utcnow()
|
paid_at= datetime.utcnow()
|
||||||
expires_at = compute_expiry(paid_at, sub.plan)
|
expires_at= compute_expiry(paid_at, sub.plan)
|
||||||
|
sub.status=sub.paid_at=paid_at; sub.expires_at=expires_at
|
||||||
sub.status = "active"
|
sub.raw_payment_json=pay; db.commit(); db.close()
|
||||||
sub.paid_at = paid_at
|
|
||||||
sub.expires_at = expires_at
|
|
||||||
sub.raw_payment_json = pay
|
|
||||||
db.commit(); db.close()
|
|
||||||
|
|
||||||
guild = bot.get_guild(GUILD_ID)
|
guild = bot.get_guild(GUILD_ID)
|
||||||
member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id))
|
member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id))
|
||||||
@ -248,21 +185,14 @@ async def lnbits_ws_listener():
|
|||||||
chan = bot.get_channel(int(sub.plan.channel_id))
|
chan = bot.get_channel(int(sub.plan.channel_id))
|
||||||
|
|
||||||
await member.add_roles(role, reason="Subscription paid")
|
await member.add_roles(role, reason="Subscription paid")
|
||||||
await chan.send(
|
await chan.send(f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** "
|
||||||
f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** "
|
f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}.")
|
||||||
f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}."
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("WS error; reconnecting in 10s")
|
logger.exception("WS error, reconnecting in 10s"); await asyncio.sleep(10)
|
||||||
await asyncio.sleep(10)
|
|
||||||
|
|
||||||
# ─── Cleanup expired ───────────────────────────────────────────────────────────
|
|
||||||
async def cleanup_expired():
|
async def cleanup_expired():
|
||||||
now = datetime.utcnow()
|
now= datetime.utcnow(); db= SessionLocal()
|
||||||
db = SessionLocal()
|
rows= db.query(Subscription).filter(Subscription.status=="active", Subscription.expires_at<=now).all()
|
||||||
rows = db.query(Subscription)\
|
|
||||||
.filter(Subscription.status=="active", Subscription.expires_at <= now)\
|
|
||||||
.all()
|
|
||||||
for sub in rows:
|
for sub in rows:
|
||||||
guild = bot.get_guild(GUILD_ID)
|
guild = bot.get_guild(GUILD_ID)
|
||||||
member = guild.get_member(int(sub.user_id))
|
member = guild.get_member(int(sub.user_id))
|
||||||
@ -270,38 +200,26 @@ async def cleanup_expired():
|
|||||||
if member and role in member.roles:
|
if member and role in member.roles:
|
||||||
try:
|
try:
|
||||||
await member.remove_roles(role, reason="Subscription expired")
|
await member.remove_roles(role, reason="Subscription expired")
|
||||||
sub.status = "expired"
|
sub.status="expired"; sub.role_removed_at=now; db.commit()
|
||||||
sub.role_removed_at = now
|
logger.info(f"Removed expired role from {member.display_name}")
|
||||||
db.commit()
|
|
||||||
logger.info(f"Removed expired {role.name} from {member.display_name}")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error removing expired role")
|
logger.exception("Failed to remove expired role")
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# ─── on_ready & scheduler ─────────────────────────────────────────────────────
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_ready():
|
async def on_ready():
|
||||||
logger.info(f"Bot logged in as {bot.user} (ID {bot.user.id})")
|
logger.info(f"Logged in as {bot.user}")
|
||||||
|
db = SessionLocal(); plans = db.query(Plan).all()
|
||||||
# dynamic command registration
|
|
||||||
db = SessionLocal()
|
|
||||||
plans = db.query(Plan).all()
|
|
||||||
for plan in plans:
|
for plan in plans:
|
||||||
@bot.tree.command(name=plan.command_name,
|
@bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}")
|
||||||
description=f"Buy {plan.name} for {plan.price_sats} sats")
|
async def _cmd(interaction, plan=plan):
|
||||||
async def _cmd(interaction: discord.Interaction, plan=plan):
|
|
||||||
await handle_plan_purchase(interaction, plan)
|
await handle_plan_purchase(interaction, plan)
|
||||||
db.close()
|
db.close()
|
||||||
await bot.tree.sync()
|
await bot.tree.sync()
|
||||||
|
|
||||||
# start WS listener
|
|
||||||
bot.loop.create_task(lnbits_ws_listener())
|
bot.loop.create_task(lnbits_ws_listener())
|
||||||
|
sched= AsyncIOScheduler(timezone="UTC")
|
||||||
# schedule midnight‐UTC expiry check
|
|
||||||
sched = AsyncIOScheduler(timezone="UTC")
|
|
||||||
sched.add_job(cleanup_expired, "cron", hour=0, minute=0)
|
sched.add_job(cleanup_expired, "cron", hour=0, minute=0)
|
||||||
sched.start()
|
sched.start()
|
||||||
|
|
||||||
# ─── Run ───────────────────────────────────────────────────────────────────────
|
if __name__=="__main__":
|
||||||
if __name__ == "__main__":
|
|
||||||
bot.run(DISCORD_TOKEN)
|
bot.run(DISCORD_TOKEN)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user