Update discord_lnbits_bot.py
This commit is contained in:
parent
c914dc1753
commit
10e51b3655
@ -7,7 +7,7 @@ import requests
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from dotenv import load_dotenv
|
||||
import qrcode
|
||||
|
||||
import discord
|
||||
from discord import File, Embed
|
||||
@ -18,6 +18,7 @@ import websockets
|
||||
from sqlalchemy import (
|
||||
create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, relationship
|
||||
|
||||
@ -30,18 +31,16 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("discord_lnbits_bot")
|
||||
|
||||
# ─── Load minimal .env & DB setup ─────────────────────────────────────────────
|
||||
load_dotenv() # for DATABASE_URL only
|
||||
# ─── DB Setup & Config Seeding ─────────────────────────────────────────────────
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
if not DATABASE_URL:
|
||||
logger.critical("Environment variable DATABASE_URL is missing.")
|
||||
logger.critical("Missing DATABASE_URL in .env")
|
||||
exit(1)
|
||||
|
||||
engine = create_engine(DATABASE_URL, future=True)
|
||||
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
Base = declarative_base()
|
||||
|
||||
# ─── ORM Models ───────────────────────────────────────────────────────────────
|
||||
class Config(Base):
|
||||
__tablename__ = "config"
|
||||
key = Column(String, primary_key=True)
|
||||
@ -56,7 +55,10 @@ class Plan(Base):
|
||||
channel_id = Column(String(64), nullable=False)
|
||||
price_sats = Column(Integer, nullable=False)
|
||||
invoice_message = Column(Text, nullable=False)
|
||||
expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False)
|
||||
expiry_type = Column(
|
||||
Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"),
|
||||
nullable=False
|
||||
)
|
||||
duration_days = Column(Integer, default=30)
|
||||
subs = relationship("Subscription", back_populates="plan")
|
||||
|
||||
@ -67,25 +69,36 @@ class Subscription(Base):
|
||||
user_id = Column(String(64), nullable=False)
|
||||
invoice_hash = Column(String(256), unique=True, nullable=False)
|
||||
invoice_request = Column(Text, nullable=False)
|
||||
status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending")
|
||||
status = Column(
|
||||
Enum("pending","active","expired","cancelled", name="sub_status"),
|
||||
nullable=False, default="pending"
|
||||
)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
paid_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
role_assigned_at = Column(DateTime)
|
||||
role_removed_at = Column(DateTime)
|
||||
raw_invoice_json = Column(Text)
|
||||
raw_payment_json = Column(Text)
|
||||
raw_invoice_json = Column(JSONB)
|
||||
raw_payment_json = Column(JSONB)
|
||||
plan = relationship("Plan", back_populates="subs")
|
||||
|
||||
# create tables & seed Config keys
|
||||
Base.metadata.create_all(engine)
|
||||
logger.info("✅ Database schema ready.")
|
||||
|
||||
session = SessionLocal()
|
||||
for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"):
|
||||
if not session.get(Config, key):
|
||||
session.add(Config(key=key, value=""))
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
# ─── Config helper ────────────────────────────────────────────────────────────
|
||||
def get_cfg(key: str) -> str:
|
||||
session = SessionLocal()
|
||||
cfg = session.get(Config, key)
|
||||
session.close()
|
||||
return cfg.value if cfg else ""
|
||||
db = SessionLocal()
|
||||
row = db.get(Config, key)
|
||||
db.close()
|
||||
return row.value if row else ""
|
||||
|
||||
# Read core settings from DB
|
||||
DISCORD_TOKEN = get_cfg("discord_token")
|
||||
@ -93,15 +106,14 @@ GUILD_ID = int(get_cfg("guild_id") or 0)
|
||||
LNBITS_URL = get_cfg("lnbits_url")
|
||||
LNBITS_API_KEY = get_cfg("lnbits_api_key")
|
||||
|
||||
# Validate
|
||||
for var_name, val in [
|
||||
for name, val in (
|
||||
("discord_token", DISCORD_TOKEN),
|
||||
("guild_id", GUILD_ID),
|
||||
("lnbits_url", LNBITS_URL),
|
||||
("lnbits_api_key", LNBITS_API_KEY),
|
||||
]:
|
||||
):
|
||||
if not val:
|
||||
logger.critical(f"Config key '{var_name}' is missing or empty in DB.")
|
||||
logger.critical(f"Config '{name}' is empty in DB. Fill it via the UI first.")
|
||||
exit(1)
|
||||
|
||||
# Build WebSocket URL
|
||||
@ -119,17 +131,16 @@ intents = discord.Intents.default()
|
||||
intents.members = True
|
||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||
|
||||
# ─── Utility: expiry computation ───────────────────────────────────────────────
|
||||
# ─── Expiry computation ────────────────────────────────────────────────────────
|
||||
def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
|
||||
if plan.expiry_type == "calendar_month":
|
||||
first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
return first + relativedelta(months=1)
|
||||
elif plan.expiry_type == "rolling_days":
|
||||
if plan.expiry_type == "rolling_days":
|
||||
return paid_at + relativedelta(days=plan.duration_days)
|
||||
else: # fixed_date
|
||||
return paid_at # assume UI sets exact expires_at elsewhere
|
||||
return paid_at # fixed_date handled via UI
|
||||
|
||||
# ─── Handle a plan purchase ───────────────────────────────────────────────────
|
||||
# ─── Handle plan purchase ─────────────────────────────────────────────────────
|
||||
async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
|
||||
invoice_ch = bot.get_channel(int(plan.channel_id))
|
||||
if not invoice_ch:
|
||||
@ -149,47 +160,40 @@ async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
def create_invoice():
|
||||
resp = requests.post(
|
||||
r = requests.post(
|
||||
f"{LNBITS_URL}/api/v1/payments",
|
||||
json=invoice_data,
|
||||
headers=headers,
|
||||
timeout=10
|
||||
json=invoice_data, headers=headers, timeout=10
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
try:
|
||||
inv = await loop.run_in_executor(None, create_invoice)
|
||||
except Exception:
|
||||
logger.exception("Invoice creation failed")
|
||||
return await interaction.response.send_message(
|
||||
"❌ Failed to generate invoice. Try again later.",
|
||||
ephemeral=True
|
||||
"❌ Failed to generate invoice.", ephemeral=True
|
||||
)
|
||||
|
||||
pay_req = inv["bolt11"]
|
||||
pay_hash = inv["payment_hash"]
|
||||
|
||||
# Store pending subscription
|
||||
db = SessionLocal()
|
||||
sub = Subscription(
|
||||
plan_id = plan.id,
|
||||
user_id = str(interaction.user.id),
|
||||
invoice_hash = pay_hash,
|
||||
invoice_request = pay_req,
|
||||
raw_invoice_json= json.dumps(inv)
|
||||
raw_invoice_json= inv
|
||||
)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.close()
|
||||
logger.info(f"Stored pending subscription {pay_hash} for user {interaction.user.id}")
|
||||
db.add(sub); db.commit(); db.close()
|
||||
logger.info(f"Stored pending {pay_hash} for user {interaction.user.id}")
|
||||
|
||||
# Generate QR code
|
||||
def make_qr_buf():
|
||||
img = requests.compat.BytesIO()
|
||||
qrcode.make(pay_req).save(img, format="PNG")
|
||||
img.seek(0)
|
||||
return img
|
||||
buf = io.BytesIO()
|
||||
qrcode.make(pay_req).save(buf, format="PNG")
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
qr_buf = await loop.run_in_executor(None, make_qr_buf)
|
||||
qr_file = File(qr_buf, filename="invoice.png")
|
||||
@ -208,7 +212,7 @@ async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
|
||||
f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True
|
||||
)
|
||||
|
||||
# ─── WebSocket listener for payments ──────────────────────────────────────────
|
||||
# ─── WS listener ───────────────────────────────────────────────────────────────
|
||||
async def lnbits_ws_listener():
|
||||
await bot.wait_until_ready()
|
||||
logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}")
|
||||
@ -216,21 +220,18 @@ async def lnbits_ws_listener():
|
||||
try:
|
||||
async with websockets.connect(LNBITS_WS_URL) as ws:
|
||||
async for msg in ws:
|
||||
try:
|
||||
data = json.loads(msg)
|
||||
pay = data.get("payment", {})
|
||||
hsh = pay.get("checking_id") or pay.get("payment_hash")
|
||||
if not (hsh and pay.get("status") == "success"):
|
||||
if not (hsh and pay.get("status")=="success"):
|
||||
continue
|
||||
|
||||
# Activate subscription
|
||||
db = SessionLocal()
|
||||
sub = db.query(Subscription)\
|
||||
.filter_by(invoice_hash=hsh, status="pending")\
|
||||
.first()
|
||||
if not sub:
|
||||
db.close()
|
||||
continue
|
||||
db.close(); continue
|
||||
|
||||
paid_at = datetime.utcnow()
|
||||
expires_at = compute_expiry(paid_at, sub.plan)
|
||||
@ -238,11 +239,9 @@ async def lnbits_ws_listener():
|
||||
sub.status = "active"
|
||||
sub.paid_at = paid_at
|
||||
sub.expires_at = expires_at
|
||||
sub.raw_payment_json = json.dumps(pay)
|
||||
db.commit()
|
||||
db.close()
|
||||
sub.raw_payment_json = pay
|
||||
db.commit(); db.close()
|
||||
|
||||
# Assign role & announce
|
||||
guild = bot.get_guild(GUILD_ID)
|
||||
member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id))
|
||||
role = guild.get_role(int(sub.plan.role_id))
|
||||
@ -250,23 +249,21 @@ async def lnbits_ws_listener():
|
||||
|
||||
await member.add_roles(role, reason="Subscription paid")
|
||||
await chan.send(
|
||||
f"🎉 {member.mention} paid **{sub.plan.price_sats} sats**"
|
||||
f" for **{sub.plan.name}**! Expires {expires_at:%Y-%m-%d}."
|
||||
f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** "
|
||||
f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}."
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error processing WS message")
|
||||
except Exception:
|
||||
logger.exception("WS connection error, retry in 10s")
|
||||
logger.exception("WS error; reconnecting in 10s")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# ─── Cleanup expired subscriptions ─────────────────────────────────────────────
|
||||
# ─── Cleanup expired ───────────────────────────────────────────────────────────
|
||||
async def cleanup_expired():
|
||||
now = datetime.utcnow()
|
||||
db = SessionLocal()
|
||||
expired = db.query(Subscription)\
|
||||
rows = db.query(Subscription)\
|
||||
.filter(Subscription.status=="active", Subscription.expires_at <= now)\
|
||||
.all()
|
||||
for sub in expired:
|
||||
for sub in rows:
|
||||
guild = bot.get_guild(GUILD_ID)
|
||||
member = guild.get_member(int(sub.user_id))
|
||||
role = guild.get_role(int(sub.plan.role_id))
|
||||
@ -276,42 +273,35 @@ async def cleanup_expired():
|
||||
sub.status = "expired"
|
||||
sub.role_removed_at = now
|
||||
db.commit()
|
||||
logger.info(f"Removed {role.name} from {member.display_name}")
|
||||
logger.info(f"Removed expired {role.name} from {member.display_name}")
|
||||
except Exception:
|
||||
logger.exception("Failed to remove expired role")
|
||||
logger.exception("Error removing expired role")
|
||||
db.close()
|
||||
|
||||
# ─── on_ready: dynamic commands + start jobs ──────────────────────────────────
|
||||
# ─── on_ready & scheduler ─────────────────────────────────────────────────────
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})")
|
||||
logger.info(f"Bot logged in as {bot.user} (ID {bot.user.id})")
|
||||
|
||||
# Register commands for each plan
|
||||
# dynamic command registration
|
||||
db = SessionLocal()
|
||||
plans = db.query(Plan).all()
|
||||
for plan in plans:
|
||||
@bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}")
|
||||
@bot.tree.command(name=plan.command_name,
|
||||
description=f"Buy {plan.name} for {plan.price_sats} sats")
|
||||
async def _cmd(interaction: discord.Interaction, plan=plan):
|
||||
await handle_plan_purchase(interaction, plan)
|
||||
db.close()
|
||||
await bot.tree.sync()
|
||||
|
||||
# Start WS listener
|
||||
# start WS listener
|
||||
bot.loop.create_task(lnbits_ws_listener())
|
||||
|
||||
# Schedule cleanup daily at midnight UTC
|
||||
scheduler = AsyncIOScheduler(timezone="UTC")
|
||||
scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0)
|
||||
scheduler.start()
|
||||
# schedule midnight‐UTC expiry check
|
||||
sched = AsyncIOScheduler(timezone="UTC")
|
||||
sched.add_job(cleanup_expired, "cron", hour=0, minute=0)
|
||||
sched.start()
|
||||
|
||||
# ─── Entry point ───────────────────────────────────────────────────────────────
|
||||
# ─── Run ───────────────────────────────────────────────────────────────────────
|
||||
if __name__ == "__main__":
|
||||
# Seed Config table with required keys if missing
|
||||
session = SessionLocal()
|
||||
for key in ("discord_token","guild_id","lnbits_url","lnbits_api_key"):
|
||||
if not session.get(Config, key):
|
||||
session.add(Config(key=key, value=""))
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
bot.run(DISCORD_TOKEN)
|
||||
|
Loading…
x
Reference in New Issue
Block a user