Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions lnvps_api/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,19 @@ impl Worker {
// Only proceed with deletion if the VM is still in the unpaid (new) state.
match self.db.get_vm(vm.id).await {
Ok(current_vm) if current_vm.created == current_vm.expires => {
if self
.db
.count_active_vm_payments(vm.id)
.await
.unwrap_or(0)
> 0
{
info!(
"VM {} has pending unpaid payments, skipping deletion",
vm.id
);
continue;
}
info!("Deleting unpaid VM {}", vm.id);
if let Err(e) = self.provisioner.delete_vm(vm.id).await {
error!("Failed to delete unpaid VM {}: {}", vm.id, e);
Expand Down Expand Up @@ -2472,4 +2485,96 @@ mod tests {
);
Ok(())
}

/// An unpaid VM (new state, older than 1 hour) with a non-expired pending payment must NOT
/// be deleted by check_vms.
#[tokio::test]
async fn test_check_vms_skips_unpaid_vm_with_pending_payment() -> Result<()> {
use lnvps_db::{EncryptedString, PaymentMethod, PaymentType, VmPayment};

let db = Arc::new(MockDb::default());
let old = Utc::now().sub(TimeDelta::hours(2));
let vm = add_vm_with_state(&db, old, old).await?;
let vm_id = vm.id;

// Add a pending (unpaid, not-yet-expired) payment for this VM.
let payment = VmPayment {
id: vec![1u8; 32],
vm_id,
created: Utc::now(),
expires: Utc::now().add(TimeDelta::minutes(10)),
amount: 1000,
currency: "BTC".to_string(),
payment_method: PaymentMethod::Lightning,
payment_type: PaymentType::Renewal,
external_data: EncryptedString::from("test"),
external_id: None,
is_paid: false,
rate: 1.0,
time_value: 2592000,
tax: 0,
processing_fee: 0,
upgrade_params: None,
paid_at: None,
};
db.insert_vm_payment(&payment).await?;

let worker = setup_worker(db.clone()).await?;
worker.check_vms().await?;

// VM must NOT be deleted because there is a pending payment.
let vms = db.vms.lock().await;
let deleted = vms.get(&vm_id).map(|v| v.deleted).unwrap_or(true);
assert!(
!deleted,
"Unpaid VM with a non-expired pending payment should not be deleted"
);
Ok(())
}

/// An unpaid VM (new state, older than 1 hour) whose only payment is already expired must
/// still be deleted by check_vms.
#[tokio::test]
async fn test_check_vms_deletes_unpaid_vm_with_only_expired_payment() -> Result<()> {
use lnvps_db::{EncryptedString, PaymentMethod, PaymentType, VmPayment};

let db = Arc::new(MockDb::default());
let old = Utc::now().sub(TimeDelta::hours(2));
let vm = add_vm_with_state(&db, old, old).await?;
let vm_id = vm.id;

// Add a payment whose invoice has already expired.
let payment = VmPayment {
id: vec![2u8; 32],
vm_id,
created: old,
expires: old.add(TimeDelta::minutes(10)), // expired long ago
amount: 1000,
currency: "BTC".to_string(),
payment_method: PaymentMethod::Lightning,
payment_type: PaymentType::Renewal,
external_data: EncryptedString::from("test"),
external_id: None,
is_paid: false,
rate: 1.0,
time_value: 2592000,
tax: 0,
processing_fee: 0,
upgrade_params: None,
paid_at: None,
};
db.insert_vm_payment(&payment).await?;

let worker = setup_worker(db.clone()).await?;
worker.check_vms().await?;

// VM should be soft-deleted because the only payment is expired.
let vms = db.vms.lock().await;
let deleted = vms.get(&vm_id).map(|v| v.deleted).unwrap_or(false);
assert!(
deleted,
"Unpaid VM with only an expired payment should still be deleted"
);
Ok(())
}
}
7 changes: 7 additions & 0 deletions lnvps_api_common/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,13 @@ impl LNVpsDbBase for MockDb {
.cloned())
}

async fn count_active_vm_payments(&self, vm_id: u64) -> DbResult<u64> {
let p = self.payments.lock().await;
Ok(p.iter()
.filter(|p| p.vm_id == vm_id && !p.is_paid && p.expires > Utc::now())
.count() as u64)
}

async fn list_custom_pricing(&self, _TB: u64) -> DbResult<Vec<VmCustomPricing>> {
let p = self.custom_pricing.lock().await;
Ok(p.values().cloned().collect())
Expand Down
3 changes: 3 additions & 0 deletions lnvps_db/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ pub trait LNVpsDbBase: Send + Sync {
/// Return the most recently settled invoice
async fn last_paid_invoice(&self) -> DbResult<Option<VmPayment>>;

/// Count active (unpaid, non-expired) payments for a VM
async fn count_active_vm_payments(&self, vm_id: u64) -> DbResult<u64>;

/// Return the list of active custom pricing models for a given region
async fn list_custom_pricing(&self, region_id: u64) -> DbResult<Vec<VmCustomPricing>>;

Expand Down
9 changes: 9 additions & 0 deletions lnvps_db/src/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,15 @@ impl LNVpsDbBase for LNVpsDbMysql {
.await?)
}

async fn count_active_vm_payments(&self, vm_id: u64) -> DbResult<u64> {
let (count,): (i64,) =
sqlx::query_as("select count(*) from vm_payment where vm_id = ? and is_paid = false and expires > NOW()")
.bind(vm_id)
.fetch_one(&self.db)
.await?;
Ok(count as u64)
}

async fn list_custom_pricing(&self, region_id: u64) -> DbResult<Vec<VmCustomPricing>> {
Ok(
sqlx::query_as("select * from vm_custom_pricing where region_id = ?")
Expand Down
Loading