diff --git a/lnvps_api/src/worker.rs b/lnvps_api/src/worker.rs index 6b82fa9..e86ca02 100644 --- a/lnvps_api/src/worker.rs +++ b/lnvps_api/src/worker.rs @@ -492,6 +492,19 @@ impl Worker { // Only proceed with deletion if the VM is still in the unpaid (new) state. match self.db.get_vm(vm.id).await { Ok(current_vm) if current_vm.created == current_vm.expires => { + if self + .db + .count_active_vm_payments(vm.id) + .await + .unwrap_or(0) + > 0 + { + info!( + "VM {} has pending unpaid payments, skipping deletion", + vm.id + ); + continue; + } info!("Deleting unpaid VM {}", vm.id); if let Err(e) = self.provisioner.delete_vm(vm.id).await { error!("Failed to delete unpaid VM {}: {}", vm.id, e); @@ -2472,4 +2485,96 @@ mod tests { ); Ok(()) } + + /// An unpaid VM (new state, older than 1 hour) with a non-expired pending payment must NOT + /// be deleted by check_vms. + #[tokio::test] + async fn test_check_vms_skips_unpaid_vm_with_pending_payment() -> Result<()> { + use lnvps_db::{EncryptedString, PaymentMethod, PaymentType, VmPayment}; + + let db = Arc::new(MockDb::default()); + let old = Utc::now().sub(TimeDelta::hours(2)); + let vm = add_vm_with_state(&db, old, old).await?; + let vm_id = vm.id; + + // Add a pending (unpaid, not-yet-expired) payment for this VM. + let payment = VmPayment { + id: vec![1u8; 32], + vm_id, + created: Utc::now(), + expires: Utc::now().add(TimeDelta::minutes(10)), + amount: 1000, + currency: "BTC".to_string(), + payment_method: PaymentMethod::Lightning, + payment_type: PaymentType::Renewal, + external_data: EncryptedString::from("test"), + external_id: None, + is_paid: false, + rate: 1.0, + time_value: 2592000, + tax: 0, + processing_fee: 0, + upgrade_params: None, + paid_at: None, + }; + db.insert_vm_payment(&payment).await?; + + let worker = setup_worker(db.clone()).await?; + worker.check_vms().await?; + + // VM must NOT be deleted because there is a pending payment. + let vms = db.vms.lock().await; + let deleted = vms.get(&vm_id).map(|v| v.deleted).unwrap_or(true); + assert!( + !deleted, + "Unpaid VM with a non-expired pending payment should not be deleted" + ); + Ok(()) + } + + /// An unpaid VM (new state, older than 1 hour) whose only payment is already expired must + /// still be deleted by check_vms. + #[tokio::test] + async fn test_check_vms_deletes_unpaid_vm_with_only_expired_payment() -> Result<()> { + use lnvps_db::{EncryptedString, PaymentMethod, PaymentType, VmPayment}; + + let db = Arc::new(MockDb::default()); + let old = Utc::now().sub(TimeDelta::hours(2)); + let vm = add_vm_with_state(&db, old, old).await?; + let vm_id = vm.id; + + // Add a payment whose invoice has already expired. + let payment = VmPayment { + id: vec![2u8; 32], + vm_id, + created: old, + expires: old.add(TimeDelta::minutes(10)), // expired long ago + amount: 1000, + currency: "BTC".to_string(), + payment_method: PaymentMethod::Lightning, + payment_type: PaymentType::Renewal, + external_data: EncryptedString::from("test"), + external_id: None, + is_paid: false, + rate: 1.0, + time_value: 2592000, + tax: 0, + processing_fee: 0, + upgrade_params: None, + paid_at: None, + }; + db.insert_vm_payment(&payment).await?; + + let worker = setup_worker(db.clone()).await?; + worker.check_vms().await?; + + // VM should be soft-deleted because the only payment is expired. + let vms = db.vms.lock().await; + let deleted = vms.get(&vm_id).map(|v| v.deleted).unwrap_or(false); + assert!( + deleted, + "Unpaid VM with only an expired payment should still be deleted" + ); + Ok(()) + } } diff --git a/lnvps_api_common/src/mock.rs b/lnvps_api_common/src/mock.rs index 471555e..92a16c1 100644 --- a/lnvps_api_common/src/mock.rs +++ b/lnvps_api_common/src/mock.rs @@ -852,6 +852,13 @@ impl LNVpsDbBase for MockDb { .cloned()) } + async fn count_active_vm_payments(&self, vm_id: u64) -> DbResult { + let p = self.payments.lock().await; + Ok(p.iter() + .filter(|p| p.vm_id == vm_id && !p.is_paid && p.expires > Utc::now()) + .count() as u64) + } + async fn list_custom_pricing(&self, _TB: u64) -> DbResult> { let p = self.custom_pricing.lock().await; Ok(p.values().cloned().collect()) diff --git a/lnvps_db/src/lib.rs b/lnvps_db/src/lib.rs index 4ebc226..891f029 100644 --- a/lnvps_db/src/lib.rs +++ b/lnvps_db/src/lib.rs @@ -286,6 +286,9 @@ pub trait LNVpsDbBase: Send + Sync { /// Return the most recently settled invoice async fn last_paid_invoice(&self) -> DbResult>; + /// Count active (unpaid, non-expired) payments for a VM + async fn count_active_vm_payments(&self, vm_id: u64) -> DbResult; + /// Return the list of active custom pricing models for a given region async fn list_custom_pricing(&self, region_id: u64) -> DbResult>; diff --git a/lnvps_db/src/mysql.rs b/lnvps_db/src/mysql.rs index fb8c49f..ccf624a 100644 --- a/lnvps_db/src/mysql.rs +++ b/lnvps_db/src/mysql.rs @@ -821,6 +821,15 @@ impl LNVpsDbBase for LNVpsDbMysql { .await?) } + async fn count_active_vm_payments(&self, vm_id: u64) -> DbResult { + let (count,): (i64,) = + sqlx::query_as("select count(*) from vm_payment where vm_id = ? and is_paid = false and expires > NOW()") + .bind(vm_id) + .fetch_one(&self.db) + .await?; + Ok(count as u64) + } + async fn list_custom_pricing(&self, region_id: u64) -> DbResult> { Ok( sqlx::query_as("select * from vm_custom_pricing where region_id = ?")